jhaozhuang committed on Commit 77771e4 · 1 Parent(s): e826d2f
app
This view is limited to 50 files because it contains too many changes.
- BidirectionalTranslation/LICENSE +26 -0
- BidirectionalTranslation/README.md +72 -0
- BidirectionalTranslation/data/__init__.py +100 -0
- BidirectionalTranslation/data/aligned_dataset.py +60 -0
- BidirectionalTranslation/data/base_dataset.py +164 -0
- BidirectionalTranslation/data/image_folder.py +66 -0
- BidirectionalTranslation/data/singleCo_dataset.py +85 -0
- BidirectionalTranslation/data/singleSr_dataset.py +73 -0
- BidirectionalTranslation/models/__init__.py +61 -0
- BidirectionalTranslation/models/base_model.py +277 -0
- BidirectionalTranslation/models/cycle_ganstft_model.py +103 -0
- BidirectionalTranslation/models/networks.py +1375 -0
- BidirectionalTranslation/options/base_options.py +142 -0
- BidirectionalTranslation/options/test_options.py +19 -0
- BidirectionalTranslation/requirements.txt +8 -0
- BidirectionalTranslation/scripts/test_western2manga.sh +49 -0
- BidirectionalTranslation/test.py +71 -0
- BidirectionalTranslation/util/html.py +86 -0
- BidirectionalTranslation/util/util.py +136 -0
- BidirectionalTranslation/util/visualizer.py +221 -0
- app.py +507 -0
- assets/example_0/input.jpg +0 -0
- assets/example_0/ref1.jpg +0 -0
- assets/example_1/input.jpg +0 -0
- assets/example_1/ref1.jpg +0 -0
- assets/example_1/ref2.jpg +0 -0
- assets/example_1/ref3.jpg +0 -0
- assets/example_2/input.png +0 -0
- assets/example_2/ref1.png +0 -0
- assets/example_2/ref2.png +0 -0
- assets/example_2/ref3.png +0 -0
- assets/example_3/input.png +0 -0
- assets/example_3/ref1.png +0 -0
- assets/example_3/ref2.png +0 -0
- assets/example_3/ref3.png +0 -0
- assets/example_4/input.jpg +0 -0
- assets/example_4/ref1.jpg +0 -0
- assets/example_4/ref2.jpg +0 -0
- assets/example_4/ref3.jpg +0 -0
- assets/example_5/input.png +0 -0
- assets/example_5/ref1.png +0 -0
- assets/example_5/ref2.png +0 -0
- assets/example_5/ref3.png +0 -0
- assets/mask.png +0 -0
- diffusers/.github/ISSUE_TEMPLATE/bug-report.yml +110 -0
- diffusers/.github/ISSUE_TEMPLATE/config.yml +4 -0
- diffusers/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
- diffusers/.github/ISSUE_TEMPLATE/feedback.md +12 -0
- diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml +31 -0
- diffusers/.github/ISSUE_TEMPLATE/translate.md +29 -0
BidirectionalTranslation/LICENSE
ADDED
@@ -0,0 +1,26 @@
+Manga Filling Style Conversion with Screentone Variational Autoencoder
+
+Copyright (c) 2020 The Chinese University of Hong Kong
+
+Copyright and License Information: The source code, the binary executable, and all data files (hereafter, Software) are copyrighted by The Chinese University of Hong Kong and Tien-Tsin Wong (hereafter, Author), Copyright (c) 2021 The Chinese University of Hong Kong. All Rights Reserved.
+
+The Author grants to you ("Licensee") a non-exclusive license to use the Software for academic, research and commercial purposes, without fee. For commercial use, Licensee should submit a WRITTEN NOTICE to the Author. The notice should clearly identify the software package/system/hardware (name, version, and/or model number) using the Software. Licensee may distribute the Software to third parties provided that the copyright notice and this statement appears on all copies. Licensee agrees that the copyright notice and this statement will appear on all copies of the Software, or portions thereof. The Author retains exclusive ownership of the Software.
+
+Licensee may make derivatives of the Software, provided that such derivatives can only be used for the purposes specified in the license grant above.
+
+THE AUTHOR MAKES NO REPRESENTATIONS OR WARRANTIES ABOUT THE SUITABILITY OF THE SOFTWARE, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR NON-INFRINGEMENT. THE AUTHOR SHALL NOT BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING, MODIFYING OR DISTRIBUTING THE SOFTWARE OR ITS DERIVATIVES.
+
+By using the source code, Licensee agrees to cite the following papers in
+Licensee's publication/work:
+
+Minshan Xie, Chengze Li, Xueting Liu and Tien-Tsin Wong
+"Manga Filling Style Conversion with Screentone Variational Autoencoder"
+ACM Transactions on Graphics (SIGGRAPH Asia 2020 issue), Vol. 39, No. 6, December 2020, pp. 226:1-226:15.
+
+
+By using or copying the Software, Licensee agrees to abide by the intellectual property laws, and all other applicable laws of the U.S., and the terms of this license.
+
+Author shall have the right to terminate this license immediately by written notice upon Licensee's breach of, or non-compliance with, any of its terms.
+Licensee may be held legally responsible for any infringement that is caused or encouraged by Licensee's failure to abide by the terms of this license.
+
+For more information or comments, send mail to: [email protected]
BidirectionalTranslation/README.md
ADDED
@@ -0,0 +1,72 @@
+# Bidirectional Translation
+
+PyTorch implementation for multimodal comic-to-manga translation.
+
+**Note**: The current software works well with PyTorch 1.6.0+.
+
+## Prerequisites
+- Linux
+- Python 3
+- CPU or NVIDIA GPU + CUDA CuDNN
+
+## Getting Started
+### Installation
+- Clone this repo:
+```bash
+git clone https://github.com/msxie/ScreenStyle.git
+cd ScreenStyle/MangaScreening
+```
+- Install PyTorch and dependencies from http://pytorch.org
+- Install the Python library [tensorboardX](https://github.com/lanpa/tensorboardX)
+- Install other libraries
+For pip users:
+```
+pip install -r requirements.txt
+```
+
+## Data preparation
+Training requires paired data (manga images, western images, and their line drawings).
+The line drawings can be extracted using [MangaLineExtraction](https://github.com/ljsabc/MangaLineExtraction).
+
+```
+${DATASET}
+|-- color2manga
+|   |-- val
+|   |   |-- ${FOLDER}
+|   |   |   |-- imgs
+|   |   |   |   |-- 0001.png
+|   |   |   |   |-- ...
+|   |   |   |-- line
+|   |   |   |   |-- 0001.png
+|   |   |   |   |-- ...
+```
+
+### Use a Pre-trained Model
+- Download the pre-trained [ScreenVAE](https://drive.google.com/file/d/1OBxWHjijMwi9gfTOfDiFiHRZA_CXNSWr/view?usp=sharing) model and place it under the `checkpoints/ScreenVAE/` folder.
+
+- Download the pre-trained [color2manga](https://drive.google.com/file/d/18-N1W0t3igWLJWFyplNZ5Fa2YHWASCZY/view?usp=sharing) model and place it under the `checkpoints/color2manga/` folder.
+- Generate results with the model
+```bash
+bash ./scripts/test_western2manga.sh
+```
+
+## Copyright and License
+This software is released under the [LICENSE](LICENSE), which permits both academic and commercial use.
+
+## Citation
+If you find the code helpful in your research or work, please cite the following paper.
+```
+@article{xie-2020-manga,
+    author  = {Minshan Xie and Chengze Li and Xueting Liu and Tien-Tsin Wong},
+    title   = {Manga Filling Style Conversion with Screentone Variational Autoencoder},
+    journal = {ACM Transactions on Graphics (SIGGRAPH Asia 2020 issue)},
+    month   = {December},
+    year    = {2020},
+    volume  = {39},
+    number  = {6},
+    pages   = {226:1--226:15}
+}
+```
+
+### Acknowledgements
+This code borrows heavily from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository.
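The directory tree in the README's data section can be sanity-checked before running the test script. The following is a minimal sketch and not part of the repository: `expected_dirs` and `missing_line_drawings` are hypothetical helpers, and the pairing rule (an image and its line drawing share a file stem) is an assumption drawn from the tree, not a documented guarantee.

```python
import os


def expected_dirs(dataset_root, folder, phase="val"):
    """Return the imgs/ and line/ directories described by the README tree."""
    base = os.path.join(dataset_root, "color2manga", phase, folder)
    return os.path.join(base, "imgs"), os.path.join(base, "line")


def missing_line_drawings(dataset_root, folder, phase="val"):
    """List image files that have no line drawing with the same stem."""
    imgs_dir, line_dir = expected_dirs(dataset_root, folder, phase)
    line_stems = {os.path.splitext(name)[0] for name in os.listdir(line_dir)}
    return sorted(
        name for name in os.listdir(imgs_dir)
        if os.path.splitext(name)[0] not in line_stems
    )
```

An empty list from `missing_line_drawings` suggests that every image has a matching line drawing under the assumed naming rule.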
BidirectionalTranslation/data/__init__.py
ADDED
@@ -0,0 +1,100 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>: return the size of dataset.
+    -- <__getitem__>: get a data point from data loader.
+    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import importlib
+import torch.utils.data
+from data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+    """Import the module "data/[dataset_name]_dataset.py".
+
+    In the file, the class called DatasetNameDataset() will
+    be instantiated. It has to be a subclass of BaseDataset,
+    and it is case-insensitive.
+    """
+    dataset_filename = "data." + dataset_name + "_dataset"
+    datasetlib = importlib.import_module(dataset_filename)
+
+    dataset = None
+    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+    for name, cls in datasetlib.__dict__.items():
+        if name.lower() == target_dataset_name.lower() \
+           and issubclass(cls, BaseDataset):
+            dataset = cls
+
+    if dataset is None:
+        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+    return dataset
+
+
+def get_option_setter(dataset_name):
+    """Return the static method <modify_commandline_options> of the dataset class."""
+    dataset_class = find_dataset_using_name(dataset_name)
+    return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt):
+    """Create a dataset given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+        This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from data import create_dataset
+        >>> dataset = create_dataset(opt)
+    """
+    data_loader = CustomDatasetDataLoader(opt)
+    dataset = data_loader.load_data()
+    return dataset
+
+
+class CustomDatasetDataLoader():
+    """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+    def __init__(self, opt):
+        """Initialize this class
+
+        Step 1: create a dataset instance given the name [dataset_mode]
+        Step 2: create a multi-threaded data loader.
+        """
+        self.opt = opt
+        dataset_class = find_dataset_using_name(opt.dataset_mode)
+        self.dataset = dataset_class(opt)
+        print("dataset [%s] was created" % type(self.dataset).__name__)
+
+        train_sampler = None
+        if len(opt.gpu_ids) > 1:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
+
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batch_size,
+            # shuffle=not opt.serial_batches,
+            num_workers=int(opt.num_threads),
+            pin_memory=True, sampler=train_sampler
+        )
+
+    def load_data(self):
+        return self
+
+    def __len__(self):
+        """Return the number of data in the dataset"""
+        return min(len(self.dataset), self.opt.max_dataset_size)
+
+    def __iter__(self):
+        """Return a batch of data"""
+        for i, data in enumerate(self.dataloader):
+            if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                break
+            yield data
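The name-based lookup in `find_dataset_using_name` can be illustrated without importing the package. This is a minimal sketch under stated assumptions: `BaseDataset` here is a local stand-in rather than `data.base_dataset.BaseDataset`, `find_class_by_name` is a hypothetical helper, and a `SimpleNamespace` plays the role of the imported module. It also adds an `isinstance(obj, type)` guard, which the original does not have.

```python
import types


class BaseDataset:
    """Local stand-in for data.base_dataset.BaseDataset (assumption)."""


class SingleCoDataset(BaseDataset):
    """Stand-in for a concrete dataset class."""


def find_class_by_name(namespace, dataset_name, base_cls):
    """Return the class whose lowercased name equals '<dataset_name>dataset'."""
    target = (dataset_name.replace('_', '') + 'dataset').lower()
    for name, obj in namespace.items():
        # Guard with isinstance so non-class attributes never reach issubclass.
        if name.lower() == target and isinstance(obj, type) and issubclass(obj, base_cls):
            return obj
    raise NotImplementedError("no subclass matching %r" % target)


# A namespace object standing in for the dynamically imported module.
fake_module = types.SimpleNamespace(
    BaseDataset=BaseDataset, SingleCoDataset=SingleCoDataset
)
found = find_class_by_name(vars(fake_module), "singleCo", BaseDataset)
```

Because the comparison is done on lowercased names, `--dataset_mode singleCo` and a class named `SingleCoDataset` resolve to each other regardless of capitalization.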
BidirectionalTranslation/data/aligned_dataset.py
ADDED
@@ -0,0 +1,60 @@
+import os.path
+from data.base_dataset import BaseDataset, get_params, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class AlignedDataset(BaseDataset):
+    """A dataset class for paired image dataset.
+
+    It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
+    During test time, you need to prepare a directory '/path/to/data/test'.
+    """
+
+    def __init__(self, opt):
+        """Initialize this dataset class.
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        BaseDataset.__init__(self, opt)
+        self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
+        self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
+        assert(self.opt.load_size >= self.opt.crop_size)   # crop_size must not exceed the size of the loaded image
+        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
+        self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
+
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index -- a random integer for data indexing
+
+        Returns a dictionary that contains A, B, A_paths and B_paths
+            A (tensor) -- an image in the input domain
+            B (tensor) -- its corresponding image in the target domain
+            A_paths (str) -- image paths
+            B_paths (str) -- image paths (same as A_paths)
+        """
+        # read an image given a random integer index
+        AB_path = self.AB_paths[index%len(self.AB_paths)]
+        AB = Image.open(AB_path).convert('RGB')
+        # split AB image into A and B
+        w, h = AB.size
+        w2 = int(w / 2)
+        A = AB.crop((0, 0, w2, h))
+        B = AB.crop((w2, 0, w, h))
+
+        # apply the same transform to both A and B
+        transform_params = get_params(self.opt, A.size)
+        A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
+        B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
+
+        A = A_transform(A)
+        B = B_transform(B)
+
+        return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
+
+    def __len__(self):
+        """Return the number of images in the dataset, multiplied by 100 (indices wrap around in __getitem__)."""
+        return len(self.AB_paths)*100
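The side-by-side split and the wrapped indexing in `AlignedDataset` come down to a few lines of integer arithmetic. The sketch below isolates them without PIL; `split_boxes` and `wrapped_index` are hypothetical names introduced for illustration, and the boxes follow the `(left, upper, right, lower)` convention that `Image.crop` expects.

```python
def split_boxes(width, height):
    """Return crop boxes for the left (A) and right (B) halves of an AB image."""
    w2 = int(width / 2)
    box_a = (0, 0, w2, height)
    box_b = (w2, 0, width, height)
    return box_a, box_b


def wrapped_index(index, n_paths):
    """Map an index from the inflated range back onto the real path list."""
    return index % n_paths


even = split_boxes(512, 256)
odd = split_boxes(513, 256)
```

For an odd width the two halves differ by one pixel, so paired inputs are best stored with an even total width.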
BidirectionalTranslation/data/base_dataset.py
ADDED
@@ -0,0 +1,164 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image, ImageOps
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+    """This class is an abstract base class (ABC) for datasets.
+
+    To create a subclass, you need to implement the following four functions:
+    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>: return the size of dataset.
+    -- <__getitem__>: get a data point.
+    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the class; save the options in the class
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        self.opt = opt
+        self.root = opt.dataroot
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new dataset-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def __len__(self):
+        """Return the total number of images in the dataset."""
+        return 0
+
+    @abstractmethod
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index -- a random integer for data indexing
+
+        Returns:
+            a dictionary of data with their names. It usually contains the data itself and its metadata information.
+        """
+        pass
+
+
+def get_params(opt, size):
+    w, h = size
+    new_h = h
+    new_w = w
+    crop = 0
+    if opt.preprocess == 'resize_and_crop':
+        new_h = new_w = opt.load_size
+    elif opt.preprocess == 'scale_width_and_crop':
+        new_w = opt.load_size
+        new_h = opt.load_size * h // w
+
+    # x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
+    # y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
+
+    x = random.randint(crop, np.maximum(0, new_w - opt.crop_size-crop))
+    y = random.randint(crop, np.maximum(0, new_h - opt.crop_size-crop))
+
+    flip = random.random() > 0.5
+
+    return {'crop_pos': (x, y), 'flip': flip}
+
+
+def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
+    transform_list = []
+    if grayscale:
+        transform_list.append(transforms.Grayscale(1))
+    if 'resize' in opt.preprocess:
+        osize = [opt.load_size, opt.load_size]
+        transform_list.append(transforms.Resize(osize, method))
+    elif 'scale_width' in opt.preprocess:
+        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
+
+    if 'crop' in opt.preprocess:
+        if params is None:
+            # transform_list.append(transforms.RandomCrop(opt.crop_size))
+            transform_list.append(transforms.CenterCrop(opt.crop_size))
+        else:
+            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+
+    if opt.preprocess == 'none':
+        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=2**8, method=method)))
+
+    if not opt.no_flip:
+        if params is None:
+            transform_list.append(transforms.RandomHorizontalFlip())
+        elif params['flip']:
+            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+    # transform_list += [transforms.ToTensor()]
+    if convert:
+        transform_list += [transforms.ToTensor()]
+        if grayscale:
+            transform_list += [transforms.Normalize((0.5,), (0.5,))]
+        else:
+            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    return transforms.Compose(transform_list)
+
+
+def __make_power_2(img, base, method=Image.BICUBIC):
+    ow, oh = img.size
+    h = int((oh+base-1) // base * base)
+    w = int((ow+base-1) // base * base)
+    if (h == oh) and (w == ow):
+        return img
+
+    __print_size_warning(ow, oh, w, h)
+    return ImageOps.expand(img, (0, 0, w-ow, h-oh), fill=255)
+
+
+def __scale_width(img, target_width, method=Image.BICUBIC):
+    ow, oh = img.size
+    if (ow == target_width):
+        return img
+    w = target_width
+    h = int(target_width * oh / ow)
+    return img.resize((w, h), method)
+
+
+def __crop(img, pos, size):
+    ow, oh = img.size
+    x1, y1 = pos
+    tw = th = size
+    if (ow > tw or oh > th):
+        return img.crop((x1, y1, x1 + tw, y1 + th))
+    return img
+
+
+def __flip(img, flip):
+    if flip:
+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+    return img
+
+
+def __print_size_warning(ow, oh, w, h):
+    """Print warning information about image size (only print once)"""
+    if not hasattr(__print_size_warning, 'has_printed'):
+        print("The image size needs to be a multiple of the padding base. "
+              "The loaded image size was (%d, %d), so it was adjusted to "
+              "(%d, %d). This adjustment will be done to all images "
+              "whose sizes are not multiples of the padding base" % (ow, oh, w, h))
+        __print_size_warning.has_printed = True
+
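When `opt.preprocess == 'none'`, `__make_power_2` rounds each image dimension up to the next multiple of `base` and pads the right and bottom edges to reach it. The arithmetic can be sketched on its own as follows; `round_up_to_multiple` and `padding_for` are hypothetical names, and the returned tuple follows the `(left, top, right, bottom)` border order that `ImageOps.expand` accepts.

```python
def round_up_to_multiple(value, base):
    """Smallest multiple of base that is greater than or equal to value."""
    return (value + base - 1) // base * base


def padding_for(size, base=2**8):
    """Border (left, top, right, bottom) needed to pad size up to multiples of base."""
    ow, oh = size
    w = round_up_to_multiple(ow, base)
    h = round_up_to_multiple(oh, base)
    return (0, 0, w - ow, h - oh)


border = padding_for((300, 200))
```

With the default base of 256, a 300x200 image is padded to 512x256, so small inputs can grow considerably before they reach the network.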
BidirectionalTranslation/data/image_folder.py
ADDED
@@ -0,0 +1,66 @@
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
+"""
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG', '.npz', '.npy',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf")):
+    images = []
+    assert os.path.isdir(dir), '%s is not a valid directory' % dir
+
+    for root, _, fnames in sorted(os.walk(dir)):
+        for fname in fnames:
+            if is_image_file(fname):
+                path = os.path.join(root, fname)
+                images.append(path)
+    return images[:min(max_dataset_size, len(images))]
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+    def __init__(self, root, transform=None, return_paths=False,
+                 loader=default_loader):
+        imgs = make_dataset(root)
+        if len(imgs) == 0:
+            raise(RuntimeError("Found 0 images in: " + root + "\n"
+                               "Supported image extensions are: " +
+                               ",".join(IMG_EXTENSIONS)))
+
+        self.root = root
+        self.imgs = imgs
+        self.transform = transform
+        self.return_paths = return_paths
+        self.loader = loader
+
+    def __getitem__(self, index):
+        path = self.imgs[index]
+        img = self.loader(path)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.return_paths:
+            return img, path
+        else:
+            return img
+
+    def __len__(self):
+        return len(self.imgs)
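Because `is_image_file` relies on `str.endswith`, every entry in `IMG_EXTENSIONS` needs its leading dot, and each letter case has to be listed separately. One possible alternative, shown here only as a sketch, compares the lowercased suffix returned by `os.path.splitext` against a set; `IMAGE_SUFFIXES` and `has_image_suffix` are hypothetical names, not part of the repository.

```python
import os

# Lowercase suffixes only; matching is made case-insensitive below.
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.npz', '.npy'}


def has_image_suffix(filename):
    """True when the file's extension, ignoring case, is a known image suffix."""
    return os.path.splitext(filename)[1].lower() in IMAGE_SUFFIXES


# A dotless pattern also matches names that merely end in those letters.
dotless_match = 'data_npy'.endswith('npy')
```

Here `dotless_match` is `True` even though `data_npy` has no extension, whereas `has_image_suffix('data_npy')` is `False`.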
BidirectionalTranslation/data/singleCo_dataset.py
ADDED
@@ -0,0 +1,85 @@
+import os.path
+from data.base_dataset import BaseDataset, get_params, get_transform
+from data.image_folder import make_dataset
+from PIL import Image, ImageEnhance
+import random
+import numpy as np
+import torch
+import torch.nn.functional as F
+import cv2
+
+
+class SingleCoDataset(BaseDataset):
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        return parser
+
+    def __init__(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot, opt.phase, opt.folder, 'imgs')
+
+        self.A_paths = make_dataset(self.dir_A)
+
+        self.A_paths = sorted(self.A_paths)
+
+        self.A_size = len(self.A_paths)
+        # self.transform = get_transform(opt)
+
+    def __getitem__(self, index):
+        A_path = self.A_paths[index]
+
+        A_img = Image.open(A_path).convert('RGB')
+        # enhancer = ImageEnhance.Brightness(A_img)
+        # A_img = enhancer.enhance(1.5)
+        if os.path.exists(A_path.replace('imgs','line')[:-4]+'.jpg'):
+            # L_img = Image.open(A_path.replace('imgs','line')[:-4]+'.png')
+            L_img = cv2.imread(A_path.replace('imgs','line')[:-4]+'.jpg')
+            kernel = np.ones((3,3), np.uint8)
+            L_img = cv2.erode(L_img, kernel, iterations=1)
+            L_img = Image.fromarray(L_img)
+        else:
+            L_img = A_img
+        if A_img.size!=L_img.size:
+            # L_img = L_img.resize(A_img.size, Image.LANCZOS)
+            A_img = A_img.resize(L_img.size, Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
+        if A_img.size[1]>2500:
+            A_img = A_img.resize((A_img.size[0]//2, A_img.size[1]//2), Image.LANCZOS)
+
+        ow, oh = A_img.size
+        transform_params = get_params(self.opt, A_img.size)
+        A_transform = get_transform(self.opt, transform_params, grayscale=False)
+        L_transform = get_transform(self.opt, transform_params, grayscale=True)
+        A = A_transform(A_img)
+        L = L_transform(L_img)
+
+        # base = 2**9
+        # h = int((oh+base-1) // base * base)
+        # w = int((ow+base-1) // base * base)
+        # A = F.pad(A.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
+        # L = F.pad(L.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
+
+        tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+        Ai = tmp.unsqueeze(0)
+
+        return {'A': A, 'Ai': Ai, 'L': L,
+                'B': torch.zeros(1), 'Bs': torch.zeros(1), 'Bi': torch.zeros(1), 'Bl': torch.zeros(1),
+                'A_paths': A_path, 'h': oh, 'w': ow}
+
+    def __len__(self):
+        return self.A_size
+
+    def name(self):
+        return 'SingleCoDataset'
+
+
+def M_transform(feat, opt, params=None):
+    outfeat = feat.copy()
+    oh,ow = feat.shape[1:]
+    x1, y1 = params['crop_pos']
+    tw = th = opt.crop_size
+    if (ow > tw or oh > th):
+        outfeat = outfeat[:,y1:y1+th,x1:x1+tw]
+    if params['flip']:
+        outfeat = np.flip(outfeat, 2)#outfeat[:,:,::-1]
+    return torch.from_numpy(outfeat.copy()).float()*2-1.0
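The intensity channel `Ai` in `SingleCoDataset.__getitem__` is a weighted sum of the R, G and B channels using the ITU-R BT.601 luma coefficients. The sketch below applies the same weights to scalar values; `luma` is a hypothetical helper for illustration, while the dataset applies the weights to whole tensor channels.

```python
def luma(r, g, b):
    """Weighted RGB-to-intensity conversion with the BT.601 coefficients."""
    return 0.299 * r + 0.587 * g + 0.114 * b


white = luma(1.0, 1.0, 1.0)     # top of the normalized [-1, 1] range
black = luma(-1.0, -1.0, -1.0)  # bottom of the normalized [-1, 1] range
```

Since the three weights sum to 1, inputs normalized to [-1, 1] produce an intensity that stays in [-1, 1], so no renormalization is needed after the conversion.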
BidirectionalTranslation/data/singleSr_dataset.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
from data.base_dataset import BaseDataset, get_params, get_transform
|
3 |
+
from data.image_folder import make_dataset
|
4 |
+
from PIL import Image
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class SingleSrDataset(BaseDataset):
|
12 |
+
@staticmethod
|
13 |
+
def modify_commandline_options(parser, is_train):
|
14 |
+
return parser
|
15 |
+
|
16 |
+
def __init__(self, opt):
|
17 |
+
self.opt = opt
|
18 |
+
self.root = opt.dataroot
|
19 |
+
self.dir_B = os.path.join(opt.dataroot, opt.phase, opt.folder, 'imgs')
|
20 |
+
# self.dir_B = os.path.join(opt.dataroot, opt.phase, 'test/imgs', opt.folder)
|
21 |
+
|
22 |
+
self.B_paths = make_dataset(self.dir_B)
|
23 |
+
|
24 |
+
self.B_paths = sorted(self.B_paths)
|
25 |
+
|
26 |
+
self.B_size = len(self.B_paths)
|
27 |
+
# self.transform = get_transform(opt)
|
28 |
+
# print(self.B_size)
|
29 |
+
|
30 |
+
def __getitem__(self, index):
|
31 |
+
B_path = self.B_paths[index]
|
32 |
+
|
33 |
+
B_img = Image.open(B_path).convert('RGB')
|
34 |
+
if os.path.exists(B_path.replace('imgs','line').replace('.jpg','.png')):
|
35 |
+
L_img = Image.open(B_path.replace('imgs','line').replace('.jpg','.png'))#.convert('RGB')
|
36 |
+
else:
|
37 |
+
L_img = Image.open(B_path.replace('imgs','line').replace('.png','.jpg'))#.convert('RGB')
|
38 |
+
B_img = B_img.resize(L_img.size, Image.ANTIALIAS)
|
39 |
+
|
40 |
+
ow, oh = B_img.size
|
41 |
+
transform_params = get_params(self.opt, B_img.size)
|
42 |
+
B_transform = get_transform(self.opt, transform_params, grayscale=True)
|
43 |
+
B = B_transform(B_img)
|
44 |
+
L = B_transform(L_img)
|
45 |
+
|
46 |
+
# base = 2**8
|
47 |
+
# h = int((oh+base-1) // base * base)
|
48 |
+
# w = int((ow+base-1) // base * base)
|
49 |
+
# B = F.pad(B.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
|
50 |
+
# L = F.pad(L.unsqueeze(0), (0,w-ow, 0,h-oh), 'replicate').squeeze(0)
|
51 |
+
|
52 |
+
return {'B': B, 'Bs': B, 'Bi': B, 'Bl': L,
|
53 |
+
'A': torch.zeros(1), 'Ai': torch.zeros(1), 'L': torch.zeros(1),
|
54 |
+
'A_paths': B_path, 'h': oh, 'w': ow}
|
55 |
+
|
56 |
+
def __len__(self):
|
57 |
+
return self.B_size
|
58 |
+
|
59 |
+
def name(self):
|
60 |
+
return 'SingleSrDataset'
|
61 |
+
|
62 |
+
|
63 |
+
def M_transform(feat, opt, params=None):
|
64 |
+
outfeat = feat.copy()
|
65 |
+
if params is not None:
|
66 |
+
oh,ow = feat.shape[1:]
|
67 |
+
x1, y1 = params['crop_pos']
|
68 |
+
tw = th = opt.crop_size
|
69 |
+
if (ow > tw or oh > th):
|
70 |
+
outfeat = outfeat[:,y1:y1+th,x1:x1+tw]
|
71 |
+
if params['flip']:
|
72 |
+
outfeat = np.flip(outfeat, 2).copy()#outfeat[:,:,::-1]
|
73 |
+
return torch.from_numpy(outfeat).float()*2-1.0
|
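The commented-out block in `__getitem__` above pads `B` and `L` so that height and width become multiples of `base = 2**8`. As a minimal sketch of just that arithmetic (the helper names `round_up` and `pad_amounts` are hypothetical and not part of the repository), the padded size and the right/bottom padding could be computed like this:

```python
def round_up(size, base=2 ** 8):
    """Smallest multiple of `base` that is >= size (same formula as the commented-out code)."""
    return (size + base - 1) // base * base


def pad_amounts(ow, oh, base=2 ** 8):
    """Return (pad_right, pad_bottom), i.e. the (w - ow, h - oh) values passed to F.pad."""
    return round_up(ow, base) - ow, round_up(oh, base) - oh


if __name__ == "__main__":
    print(round_up(300))           # 512
    print(pad_amounts(300, 256))   # (212, 0)
```

A size that is already a multiple of 256 is left unchanged, so no padding is added in that dimension.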
BidirectionalTranslation/models/__init__.py
ADDED
@@ -0,0 +1,61 @@
1 |
+
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
2 |
+
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
3 |
+
You need to implement the following five functions:
|
4 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
5 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
6 |
+
-- <forward>: produce intermediate results.
|
7 |
+
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
8 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
9 |
+
In the function <__init__>, you need to define four lists:
|
10 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
11 |
+
-- self.model_names (str list): define networks used in our training.
|
12 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
13 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
14 |
+
Now you can use the model class by specifying flag '--model dummy'.
|
15 |
+
See our template model class 'template_model.py' for an example.
|
16 |
+
"""
|
17 |
+
|
18 |
+
import importlib
|
19 |
+
from models.base_model import BaseModel
|
20 |
+
|
21 |
+
|
22 |
+
def find_model_using_name(model_name):
|
23 |
+
"""Import the module "models/[model_name]_model.py".
|
24 |
+
In the file, the class called DatasetNameModel() will
|
25 |
+
be instantiated. It has to be a subclass of BaseModel,
|
26 |
+
and it is case-insensitive.
|
27 |
+
"""
|
28 |
+
model_filename = "models." + model_name + "_model"
|
29 |
+
modellib = importlib.import_module(model_filename)
|
30 |
+
model = None
|
31 |
+
target_model_name = model_name.replace('_', '') + 'model'
|
32 |
+
for name, cls in modellib.__dict__.items():
|
33 |
+
if name.lower() == target_model_name.lower() \
|
34 |
+
and issubclass(cls, BaseModel):
|
35 |
+
model = cls
|
36 |
+
|
37 |
+
if model is None:
|
38 |
+
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
39 |
+
exit(0)
|
40 |
+
|
41 |
+
return model
|
42 |
+
|
43 |
+
|
44 |
+
def get_option_setter(model_name):
|
45 |
+
"""Return the static method <modify_commandline_options> of the model class."""
|
46 |
+
model_class = find_model_using_name(model_name)
|
47 |
+
return model_class.modify_commandline_options
|
48 |
+
|
49 |
+
|
50 |
+
def create_model(opt, ckpt_root):
|
51 |
+
"""Create a model given the option.
|
52 |
+
This function wraps the model class selected by opt.model.
|
53 |
+
This is the main interface between this package and 'train.py'/'test.py'
|
54 |
+
Example:
|
55 |
+
>>> from models import create_model
|
56 |
+
>>> model = create_model(opt, ckpt_root)
|
57 |
+
"""
|
58 |
+
model = find_model_using_name(opt.model)
|
59 |
+
instance = model(opt, ckpt_root = ckpt_root)
|
60 |
+
print("model [%s] was created" % type(instance).__name__)
|
61 |
+
return instance
|
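The lookup in `find_model_using_name` above is pure string handling followed by a case-insensitive match. The sketch below mirrors only that string handling, without importing anything; the helper names `resolve_model_names` and `matches` are hypothetical and exist only for illustration:

```python
def resolve_model_names(model_name):
    """Mirror the string handling in find_model_using_name (no import is performed)."""
    module_name = "models." + model_name + "_model"
    target_class_name = model_name.replace("_", "") + "model"
    return module_name, target_class_name


def matches(class_name, target_class_name):
    """Case-insensitive comparison, as in the lookup loop."""
    return class_name.lower() == target_class_name.lower()


if __name__ == "__main__":
    module_name, target = resolve_model_names("cycle_ganstft")
    print(module_name)                           # models.cycle_ganstft_model
    print(target)                                # cycleganstftmodel
    print(matches("CycleGANSTFTModel", target))  # True
```

So `--model cycle_ganstft` would resolve to the class `CycleGANSTFTModel` in `models/cycle_ganstft_model.py`, provided that class subclasses `BaseModel`.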
BidirectionalTranslation/models/base_model.py
ADDED
@@ -0,0 +1,277 @@
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from . import networks
|
6 |
+
import numpy as np
|
7 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
8 |
+
|
9 |
+
class BaseModel(ABC):
|
10 |
+
"""This class is an abstract base class (ABC) for models.
|
11 |
+
To create a subclass, you need to implement the following five functions:
|
12 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
13 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
14 |
+
-- <forward>: produce intermediate results.
|
15 |
+
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
16 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, opt):
|
20 |
+
"""Initialize the BaseModel class.
|
21 |
+
|
22 |
+
Parameters:
|
23 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
24 |
+
|
25 |
+
When creating your custom class, you need to implement your own initialization.
|
26 |
+
In this function, you should first call `BaseModel.__init__(self, opt)`
|
27 |
+
Then, you need to define four lists:
|
28 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
29 |
+
-- self.model_names (str list): define networks used in our training.
|
30 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
31 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
32 |
+
"""
|
33 |
+
self.opt = opt
|
34 |
+
self.gpu_ids = opt.gpu_ids
|
35 |
+
self.isTrain = opt.isTrain
|
36 |
+
self.iter = 0
|
37 |
+
self.last_iter = 0
|
38 |
+
self.device = torch.device('cuda:{}'.format(
|
39 |
+
self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
40 |
+
# save all the checkpoints to save_dir
|
41 |
+
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
42 |
+
os.makedirs(self.save_dir, exist_ok=True)  # create the checkpoint directory (and parents) if missing
|
46 |
+
# with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
47 |
+
if opt.preprocess != 'scale_width':
|
48 |
+
torch.backends.cudnn.benchmark = True
|
49 |
+
self.loss_names = []
|
50 |
+
self.model_names = []
|
51 |
+
self.visual_names = []
|
52 |
+
self.optimizers = []
|
53 |
+
self.image_paths = []
|
54 |
+
|
55 |
+
self.label_colours = np.random.randint(255, size=(100,3))
|
56 |
+
|
57 |
+
def save_suppixel(self,l_inds):
|
58 |
+
im_target_rgb = np.array([self.label_colours[ c % 100 ] for c in l_inds])
|
59 |
+
b,h,w = l_inds.shape
|
60 |
+
im_target_rgb = im_target_rgb.reshape(b,h,w,3).transpose(0,3,1,2)/127.5-1.0
|
61 |
+
return torch.from_numpy(im_target_rgb)
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def modify_commandline_options(parser, is_train):
|
65 |
+
"""Add new model-specific options, and rewrite default values for existing options.
|
66 |
+
|
67 |
+
Parameters:
|
68 |
+
parser -- original option parser
|
69 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
the modified parser.
|
73 |
+
"""
|
74 |
+
return parser
|
75 |
+
|
76 |
+
@abstractmethod
|
77 |
+
def set_input(self, input):
|
78 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
79 |
+
|
80 |
+
Parameters:
|
81 |
+
input (dict): includes the data itself and its metadata information.
|
82 |
+
"""
|
83 |
+
pass
|
84 |
+
|
85 |
+
@abstractmethod
|
86 |
+
def forward(self):
|
87 |
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
88 |
+
pass
|
89 |
+
|
90 |
+
def is_train(self):
|
91 |
+
"""check if the current batch is good for training."""
|
92 |
+
return True
|
93 |
+
|
94 |
+
@abstractmethod
|
95 |
+
def optimize_parameters(self):
|
96 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
97 |
+
pass
|
98 |
+
|
99 |
+
def setup(self, opt):
|
100 |
+
"""Load and print networks; create schedulers
|
101 |
+
|
102 |
+
Parameters:
|
103 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
104 |
+
"""
|
105 |
+
if self.isTrain:
|
106 |
+
self.schedulers = [networks.get_scheduler(
|
107 |
+
optimizer, opt) for optimizer in self.optimizers]
|
108 |
+
if not self.isTrain or opt.continue_train:
|
109 |
+
self.load_networks(opt.epoch)
|
110 |
+
self.print_networks(opt.verbose)
|
111 |
+
|
112 |
+
def eval(self):
|
113 |
+
"""Make models eval mode during test time"""
|
114 |
+
for name in self.model_names:
|
115 |
+
if isinstance(name, str):
|
116 |
+
net = getattr(self, 'net' + name)
|
117 |
+
net.eval()
|
118 |
+
|
119 |
+
def test(self):
|
120 |
+
"""Forward function used in test time.
|
121 |
+
|
122 |
+
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
123 |
+
It also calls <compute_visuals> to produce additional visualization results
|
124 |
+
"""
|
125 |
+
with torch.no_grad():
|
126 |
+
self.forward()
|
127 |
+
self.compute_visuals()
|
128 |
+
|
129 |
+
def compute_visuals(self):
|
130 |
+
"""Calculate additional output images for visdom and HTML visualization"""
|
131 |
+
pass
|
132 |
+
|
133 |
+
def get_image_paths(self):
|
134 |
+
""" Return image paths that are used to load current data"""
|
135 |
+
return self.image_paths
|
136 |
+
|
137 |
+
def update_learning_rate(self):
|
138 |
+
"""Update learning rates for all the networks; called at the end of every epoch"""
|
139 |
+
for scheduler in self.schedulers:
|
140 |
+
scheduler.step()
|
141 |
+
lr = self.optimizers[0].param_groups[0]['lr']
|
142 |
+
print('learning rate = %.7f' % lr)
|
143 |
+
|
144 |
+
def get_current_visuals(self):
|
145 |
+
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
146 |
+
visual_ret = OrderedDict()
|
147 |
+
for name in self.visual_names:
|
148 |
+
if isinstance(name, str):
|
149 |
+
if 'Lab' in name:
|
150 |
+
labimg = getattr(self, name).cpu()
|
151 |
+
labimg[:,0,:,:]+=1
|
152 |
+
labimg[:,0,:,:]*=50
|
153 |
+
labimg[:,1:,:,:] *= 110
|
154 |
+
labimg = labimg.permute((0,2,3,1))
|
155 |
+
for i in range(labimg.shape[0]):
|
156 |
+
labimg[i,:,:,:]=lab2rgb(labimg[i,:,:,:])
|
157 |
+
visual_ret[name] = (labimg.permute((0,3,1,2))*2-1.0).to(self.device)
|
158 |
+
elif 'Fm' in name:
|
159 |
+
visual_ret[name] = self.save_suppixel(getattr(self, name).cpu()).to(self.device)
|
160 |
+
else:
|
161 |
+
visual_ret[name] = getattr(self, name)
|
162 |
+
return visual_ret
|
163 |
+
|
164 |
+
def get_current_losses(self):
|
165 |
+
"""Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
|
166 |
+
errors_ret = OrderedDict()
|
167 |
+
for name in self.loss_names:
|
168 |
+
if isinstance(name, str):
|
169 |
+
# float(...) works for both scalar tensor and float number
|
170 |
+
errors_ret[name] = float(getattr(self, 'loss_' + name))
|
171 |
+
return errors_ret
|
172 |
+
|
173 |
+
def save_networks(self, epoch):
|
174 |
+
"""Save all the networks to the disk.
|
175 |
+
|
176 |
+
Parameters:
|
177 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
178 |
+
"""
|
179 |
+
for name in self.model_names:
|
180 |
+
if isinstance(name, str):
|
181 |
+
save_filename = '%s_net_%s.pth' % (epoch, name)
|
182 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
183 |
+
# print(save_path)
|
184 |
+
net = getattr(self, 'net' + name)
|
185 |
+
|
186 |
+
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
187 |
+
torch.save(net.state_dict(), save_path)
|
188 |
+
# net.cuda(self.gpu_ids[0])
|
189 |
+
else:
|
190 |
+
torch.save(net.cpu().state_dict(), save_path)
|
191 |
+
|
192 |
+
save_filename = '%s_net_opt.pth' % (epoch)
|
193 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
194 |
+
save_dict = {'iter': str(self.iter // self.opt.print_freq * self.opt.print_freq)}
|
195 |
+
for i, name in enumerate(self.optimizer_names):
|
196 |
+
save_dict.update({name.lower(): self.optimizers[i].state_dict()})
|
197 |
+
torch.save(save_dict, save_path)
|
198 |
+
|
199 |
+
|
200 |
+
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
201 |
+
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
202 |
+
key = keys[i]
|
203 |
+
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
204 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
205 |
+
(key == 'running_mean' or key == 'running_var'):
|
206 |
+
if getattr(module, key) is None:
|
207 |
+
state_dict.pop('.'.join(keys))
|
208 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
209 |
+
(key == 'num_batches_tracked'):
|
210 |
+
state_dict.pop('.'.join(keys))
|
211 |
+
else:
|
212 |
+
self.__patch_instance_norm_state_dict(
|
213 |
+
state_dict, getattr(module, key), keys, i + 1)
|
214 |
+
|
215 |
+
def load_networks(self, epoch):
|
216 |
+
"""Load all the networks from the disk.
|
217 |
+
|
218 |
+
Parameters:
|
219 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
220 |
+
"""
|
221 |
+
for name in self.model_names:
|
222 |
+
if isinstance(name, str):
|
223 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
224 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
225 |
+
net = getattr(self, 'net' + name)
|
226 |
+
# if isinstance(net, torch.nn.DataParallel):
|
227 |
+
if isinstance(net, DDP):
|
228 |
+
net = net.module
|
229 |
+
# print(net)
|
230 |
+
print('loading the model from %s' % load_path)
|
231 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
232 |
+
# GitHub source), you can remove str() on self.device
|
233 |
+
state_dict = torch.load(
|
234 |
+
load_path, map_location=lambda storage, loc: storage.cuda())
|
235 |
+
if hasattr(state_dict, '_metadata'):
|
236 |
+
del state_dict._metadata
|
237 |
+
|
238 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
239 |
+
# need to copy keys here because we mutate in loop
|
240 |
+
#for key in list(state_dict.keys()):
|
241 |
+
# self.__patch_instance_norm_state_dict(
|
242 |
+
# state_dict, net, key.split('.'))
|
243 |
+
|
244 |
+
net.load_state_dict(state_dict)
|
245 |
+
del state_dict
|
246 |
+
|
247 |
+
def print_networks(self, verbose):
|
248 |
+
"""Print the total number of parameters in the network and (if verbose) network architecture
|
249 |
+
|
250 |
+
Parameters:
|
251 |
+
verbose (bool) -- if verbose: print the network architecture
|
252 |
+
"""
|
253 |
+
print('---------- Networks initialized -------------')
|
254 |
+
for name in self.model_names:
|
255 |
+
if isinstance(name, str):
|
256 |
+
net = getattr(self, 'net' + name)
|
257 |
+
num_params = 0
|
258 |
+
for param in net.parameters():
|
259 |
+
num_params += param.numel()
|
260 |
+
if verbose:
|
261 |
+
print(net)
|
262 |
+
print('[Network %s] Total number of parameters : %.3f M' %
|
263 |
+
(name, num_params / 1e6))
|
264 |
+
print('-----------------------------------------------')
|
265 |
+
|
266 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
267 |
+
"""Set requires_grad=False for all the networks to avoid unnecessary computations
|
268 |
+
Parameters:
|
269 |
+
nets (network list) -- a list of networks
|
270 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
271 |
+
"""
|
272 |
+
if not isinstance(nets, list):
|
273 |
+
nets = [nets]
|
274 |
+
for net in nets:
|
275 |
+
if net is not None:
|
276 |
+
for param in net.parameters():
|
277 |
+
param.requires_grad = requires_grad
|
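`save_networks` and `load_networks` above share one file-naming convention: `'%s_net_%s.pth' % (epoch, name)` per network, plus `'%s_net_opt.pth' % epoch` for optimizer state, with the stored iteration rounded down to a multiple of `print_freq`. A minimal sketch of that convention follows; the helpers `checkpoint_files` and `saved_iter`, the directory `checkpoints/demo`, and the epoch label `latest` are illustrative assumptions, not names taken from the repository:

```python
import os


def checkpoint_files(save_dir, epoch, model_names):
    """Build the per-network checkpoint paths plus the optimizer-state path."""
    net_paths = [os.path.join(save_dir, "%s_net_%s.pth" % (epoch, name))
                 for name in model_names]
    opt_path = os.path.join(save_dir, "%s_net_opt.pth" % epoch)
    return net_paths, opt_path


def saved_iter(current_iter, print_freq):
    """Round the iteration count down to a multiple of print_freq, stored as a string."""
    return str(current_iter // print_freq * print_freq)


if __name__ == "__main__":
    nets, opt = checkpoint_files("checkpoints/demo", "latest",
                                 ["G_INTSCR2RGB", "G_RGB2INTSCR", "E"])
    for path in nets:
        print(os.path.basename(path))
    print(os.path.basename(opt))
    print(saved_iter(1234, 100))
```

Because `epoch` is formatted with `%s`, either an integer or a string label can be used to select a checkpoint.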
BidirectionalTranslation/models/cycle_ganstft_model.py
ADDED
@@ -0,0 +1,103 @@
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
from .base_model import BaseModel
|
4 |
+
from . import networks
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
7 |
+
|
8 |
+
|
9 |
+
class CycleGANSTFTModel(BaseModel):
|
10 |
+
|
11 |
+
def __init__(self, opt, ckpt_root):
|
12 |
+
|
13 |
+
BaseModel.__init__(self, opt)
|
14 |
+
|
15 |
+
use_vae = True
|
16 |
+
self.interchnnls = 4
|
17 |
+
use_noise = False
|
18 |
+
self.half_size = opt.batch_size //2
|
19 |
+
self.device=opt.local_rank
|
20 |
+
self.gpu_ids=[self.device]
|
21 |
+
self.local_rank = opt.local_rank
|
22 |
+
self.cropsize = opt.crop_size
|
23 |
+
|
24 |
+
self.model_names = ['G_INTSCR2RGB','G_RGB2INTSCR','E']
|
25 |
+
self.netG_INTSCR2RGB = networks.define_G(self.interchnnls + 1, 3, opt.nz, opt.ngf, netG='unet_256',
|
26 |
+
norm='layer', nl='lrelu', use_dropout=opt.use_dropout, init_type='kaiming', init_gain=opt.init_gain,
|
27 |
+
gpu_ids=self.gpu_ids, where_add='all', upsample='bilinear', use_noise=use_noise)
|
28 |
+
self.netG_RGB2INTSCR = networks.define_G(4, self.interchnnls, 0, opt.ngf, netG='unet_256',
|
29 |
+
norm='layer', nl='lrelu', use_dropout=opt.use_dropout, init_type='kaiming', init_gain=opt.init_gain,
|
30 |
+
gpu_ids=self.gpu_ids, where_add='input', upsample='bilinear', use_noise=use_noise)
|
31 |
+
self.netE = networks.define_E(opt.output_nc, opt.nz, opt.nef, netE=opt.netE, norm='none', nl='lrelu',
|
32 |
+
init_type='xavier', init_gain=opt.init_gain, gpu_ids=self.gpu_ids, vaeLike=use_vae)
|
33 |
+
self.nets = [self.netG_INTSCR2RGB, self.netG_RGB2INTSCR, self.netE]
|
34 |
+
|
35 |
+
self.netSVAE = networks.define_SVAE(inc=1, outc=self.interchnnls, outplanes=64, blocks=3, netVAE='SVAE',
|
36 |
+
save_dir= ckpt_root+'/ScreenStyle/ScreenVAE',init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids)
|
37 |
+
|
38 |
+
|
39 |
+
def set_input(self, input):
|
40 |
+
AtoB = self.opt.direction == 'AtoB'
|
41 |
+
self.real_RGB = input['A'].to(self.device)
|
42 |
+
self.real_Ai = self.grayscale(self.real_RGB)
|
43 |
+
self.real_L = input['L'].to(self.device)
|
44 |
+
self.real_ML = input['Bl'].to(self.device)
|
45 |
+
self.real_M = input['B'].to(self.device)
|
46 |
+
|
47 |
+
self.h = input['h']
|
48 |
+
self.w = input['w']
|
49 |
+
|
50 |
+
def grayscale(self, input_image):
|
51 |
+
rate = torch.Tensor([0.299, 0.587, 0.114]).reshape(1, 3, 1, 1).to(input_image.device)
|
52 |
+
# tmp = input_image[:,0, ...] * 0.299 + input_image[:,1, ...] * 0.587 + input_image[:,2, ...] * 0.114
|
53 |
+
return (input_image * rate).sum(1, keepdim=True)
|
54 |
+
|
55 |
+
def forward(self, AtoB=True, sty=None):
|
56 |
+
if AtoB:
|
57 |
+
real_LRGB = torch.cat([self.real_L, self.real_RGB],1)
|
58 |
+
fake_SCR = self.netG_RGB2INTSCR(real_LRGB)
|
59 |
+
fake_M = self.netSVAE(fake_SCR, line=self.real_L, img_input=False)
|
60 |
+
fake_M = torch.clamp(fake_M, -1,1)
|
61 |
+
fake_M2 = self.norm(torch.mul(self.denorm(fake_M), self.denorm(self.real_L)))#*self.mask2
|
62 |
+
return fake_M[:,:,:self.h, :self.w], fake_M2[:,:,:self.h, :self.w], fake_SCR[:,:,:self.h, :self.w]
|
63 |
+
else:
|
64 |
+
if sty is None: # use encoded z
|
65 |
+
z0, _ = self.netE(self.real_RGB)
|
66 |
+
else:
|
67 |
+
z0 = sty
|
68 |
+
# z0 = self.get_z_random(self.real_A.size(0), self.opt.nz)
|
69 |
+
real_SCR = self.netSVAE(self.real_M, self.real_ML, output_screen_only=True) #8
|
70 |
+
real_LSCR = torch.cat([self.real_ML, real_SCR], 1)
|
71 |
+
fake_nRGB = self.netG_INTSCR2RGB(real_LSCR, z0)
|
72 |
+
fake_nRGB = torch.clamp(fake_nRGB, -1,1)
|
73 |
+
fake_RGB = self.norm(torch.mul(self.denorm(fake_nRGB), self.denorm(self.real_ML)))
|
74 |
+
return fake_RGB[:,:,:self.h, :self.w], real_SCR[:,:,:self.h, :self.w], self.real_ML[:,:,:self.h, :self.w]
|
75 |
+
|
76 |
+
def norm(self, im):
|
77 |
+
return im * 2.0 - 1
|
78 |
+
|
79 |
+
def denorm(self, im):
|
80 |
+
return (im + 1) / 2.0
|
81 |
+
|
82 |
+
def optimize_parameters(self):
|
83 |
+
pass
|
84 |
+
|
85 |
+
def get_z_random(self, batch_size, nz, random_type='gauss', truncation=False, tvalue=1):
|
86 |
+
z = None
|
87 |
+
if random_type == 'uni':
|
88 |
+
z = torch.rand(batch_size, nz) * 2.0 - 1.0
|
89 |
+
elif random_type == 'gauss':
|
90 |
+
z = torch.randn(batch_size, nz) * tvalue
|
91 |
+
# do the truncation trick
|
92 |
+
if truncation:
|
93 |
+
k = 0
|
94 |
+
while (k < 15 * nz):
|
95 |
+
if torch.max(z) <= tvalue:
|
96 |
+
break
|
97 |
+
zabs = torch.abs(z)
|
98 |
+
zz = torch.randn(batch_size, nz)
|
99 |
+
z[zabs > tvalue] = zz[zabs > tvalue]
|
100 |
+
k += 1
|
101 |
+
z = torch.clamp(z, -tvalue, tvalue)
|
102 |
+
|
103 |
+
return z.detach().to(self.device)
|
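In `forward` above, the line map is applied by moving both tensors from [-1, 1] to [0, 1] with `denorm`, multiplying, and mapping back with `norm`. The scalar sketch below reproduces that arithmetic for a single pixel; `apply_line` is a hypothetical helper added for illustration, while `norm` and `denorm` follow the formulas of the methods above:

```python
def denorm(v):
    """Map a value from [-1, 1] to [0, 1]."""
    return (v + 1) / 2.0


def norm(v):
    """Map a value from [0, 1] back to [-1, 1]."""
    return v * 2.0 - 1


def apply_line(pixel, line):
    """Multiply in [0, 1] space, then return to [-1, 1] (scalar version of the tensor code)."""
    return norm(denorm(pixel) * denorm(line))


if __name__ == "__main__":
    print(apply_line(0.5, 1.0))    # 0.5  (a white line pixel leaves the value unchanged)
    print(apply_line(0.5, -1.0))   # -1.0 (a black line pixel forces the output to black)
```

Multiplying in [0, 1] rather than in [-1, 1] is what makes the line map act as a mask: white (1.0) is the identity and black (0.0) always wins.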
BidirectionalTranslation/models/networks.py
ADDED
@@ -0,0 +1,1375 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init
|
4 |
+
import functools
|
5 |
+
from torch.optim import lr_scheduler
|
6 |
+
import numpy as np
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch.nn.modules.normalization import LayerNorm
|
9 |
+
import os
|
10 |
+
from torch.nn.utils import spectral_norm
|
11 |
+
from torchvision import models
|
12 |
+
|
13 |
+
###############################################################################
|
14 |
+
# Helper functions
|
15 |
+
###############################################################################
|
16 |
+
|
17 |
+
|
18 |
+
def init_weights(net, init_type='normal', init_gain=0.02):
|
19 |
+
"""Initialize network weights.
|
20 |
+
Parameters:
|
21 |
+
net (network) -- network to be initialized
|
22 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
23 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
24 |
+
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
|
25 |
+
work better for some applications. Feel free to try yourself.
|
26 |
+
"""
|
27 |
+
def init_func(m): # define the initialization function
|
28 |
+
classname = m.__class__.__name__
|
29 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
30 |
+
if init_type == 'normal':
|
31 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
32 |
+
elif init_type == 'xavier':
|
33 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
34 |
+
elif init_type == 'kaiming':
|
35 |
+
#init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
36 |
+
init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
|
37 |
+
elif init_type == 'orthogonal':
|
38 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
39 |
+
else:
|
40 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
41 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
42 |
+
init.constant_(m.bias.data, 0.0)
|
43 |
+
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
44 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
45 |
+
init.constant_(m.bias.data, 0.0)
|
46 |
+
|
47 |
+
print('initialize network with %s' % init_type)
|
48 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
49 |
+
|
50 |
+
|
51 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
|
52 |
+
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
|
53 |
+
Parameters:
|
54 |
+
net (network) -- the network to be initialized
|
55 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
56 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
57 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
58 |
+
Return an initialized network.
|
59 |
+
"""
|
60 |
+
if len(gpu_ids) > 0:
|
61 |
+
assert(torch.cuda.is_available())
|
62 |
+
net.to(gpu_ids[0])
|
63 |
+
if init:
|
64 |
+
init_weights(net, init_type, init_gain=init_gain)
|
65 |
+
return net


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler
    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        # Raise (not return) the error, and format the policy name into the message.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
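
The 'linear' policy in `get_scheduler` can be checked by hand. The sketch below copies the `lambda_rule` formula into a standalone function; the default values for `epoch_count`, `niter` and `niter_decay` are assumptions chosen for illustration, not values taken from this repository's options.

```python
def linear_lr_multiplier(epoch, epoch_count=1, niter=100, niter_decay=100):
    """Factor multiplied into the base learning rate at scheduler step `epoch`.

    Same expression as lambda_rule; the defaults are illustrative assumptions.
    """
    return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)


# Flat phase: the multiplier stays at 1.0 while epoch + epoch_count <= niter.
flat = [linear_lr_multiplier(e) for e in (0, 50, 99)]

# Decay phase: the multiplier drops by 1 / (niter_decay + 1) per step.
decay = [linear_lr_multiplier(e) for e in (100, 150, 199, 200)]
```

With these assumed defaults the multiplier is still `1/101` at step 199 and only reaches zero at step 200, which is why the denominator is `niter_decay + 1` rather than `niter_decay`.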

class LayerNormWarpper(nn.Module):
    def __init__(self, num_features):
        super(LayerNormWarpper, self).__init__()
        self.num_features = int(num_features)

    def forward(self, x):
        # Functional layer norm without affine parameters; it runs on whatever device x is on,
        # so no hard-coded .cuda() call is needed.
        x = nn.functional.layer_norm(x, [self.num_features, x.size()[2], x.size()[3]])
        return x

def get_norm_layer(norm_type='instance'):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | layer | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'layer':
        norm_layer = functools.partial(LayerNormWarpper)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
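
`get_norm_layer` returns a `functools.partial` factory rather than a layer instance, and later classes inspect `norm_layer.func` to decide whether their convolutions need a bias term. The sketch below mimics that pattern without PyTorch; `FakeBatchNorm`, `FakeInstanceNorm` and `needs_conv_bias` are hypothetical stand-ins introduced only for this illustration.

```python
import functools


class FakeBatchNorm:
    """Hypothetical stand-in for nn.BatchNorm2d."""
    def __init__(self, num_features, affine=True):
        self.num_features = num_features
        self.affine = affine


class FakeInstanceNorm:
    """Hypothetical stand-in for nn.InstanceNorm2d."""
    def __init__(self, num_features, affine=False):
        self.num_features = num_features
        self.affine = affine


def needs_conv_bias(norm_layer, batch_norm_cls=FakeBatchNorm):
    """Mirror of the use_bias check: skip the conv bias only when batch norm follows."""
    if type(norm_layer) == functools.partial:
        return norm_layer.func != batch_norm_cls
    return norm_layer != batch_norm_cls


# The factory fixes keyword arguments now; the channel count is supplied later.
batch_factory = functools.partial(FakeBatchNorm, affine=True)
inst_factory = functools.partial(FakeInstanceNorm, affine=False)
layer = batch_factory(64)
```

Note that a `norm_layer` of `None` (the 'none' option) also yields `use_bias = True` under this check.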


def get_non_linearity(layer_type='relu'):
    if layer_type == 'relu':
        nl_layer = functools.partial(nn.ReLU, inplace=True)
    elif layer_type == 'lrelu':
        nl_layer = functools.partial(
            nn.LeakyReLU, negative_slope=0.2, inplace=True)
    elif layer_type == 'elu':
        nl_layer = functools.partial(nn.ELU, inplace=True)
    elif layer_type == 'selu':
        nl_layer = functools.partial(nn.SELU, inplace=True)
    elif layer_type == 'prelu':
        nl_layer = functools.partial(nn.PReLU)
    else:
        raise NotImplementedError(
            'nonlinearity activation [%s] is not found' % layer_type)
    return nl_layer


def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', use_noise=False,
             use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl_layer = get_non_linearity(layer_type=nl)
    # print(norm, norm_layer)

    if nz == 0:
        where_add = 'input'

    if netG == 'unet_128' and where_add == 'input':
        net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                               use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_128_G' and where_add == 'input':
        net = G_Unet_add_input_G(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                                 use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_256' and where_add == 'input':
        net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                               use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_256_G' and where_add == 'input':
        net = G_Unet_add_input_G(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                                 use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
    elif netG == 'unet_128' and where_add == 'all':
        net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                             use_dropout=use_dropout, upsample=upsample)
    elif netG == 'unet_256' and where_add == 'all':
        net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
                             use_dropout=use_dropout, upsample=upsample)
    else:
        # Report the requested name (netG); `net` is still None at this point.
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    # print(net)
    return init_net(net, init_type, init_gain, gpu_ids)


def define_C(input_nc, output_nc, nz, ngf, netC='unet_128', norm='instance', nl='relu',
             use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], upsample='basic'):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl_layer = get_non_linearity(layer_type=nl)

    if netC == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netC == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netC == 'unet_128':
        net = G_Unet_add_input_C(input_nc, output_nc, 0, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
                                 use_dropout=use_dropout, upsample=upsample)
    elif netC == 'unet_256':
        net = G_Unet_add_input(input_nc, output_nc, 0, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
                               use_dropout=use_dropout, upsample=upsample)
    elif netC == 'unet_32':
        net = G_Unet_add_input(input_nc, output_nc, 0, 5, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
                               use_dropout=use_dropout, upsample=upsample)
    else:
        # Report the requested name (netC); `net` is still None at this point.
        raise NotImplementedError('Generator model name [%s] is not recognized' % netC)

    return init_net(net, init_type, init_gain, gpu_ids)


def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl = 'lrelu'  # use leaky relu for D
    nl_layer = get_non_linearity(layer_type=nl)

    if netD == 'basic_128':
        net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
    elif netD == 'basic_256':
        net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
    elif netD == 'basic_128_multi':
        net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
    elif netD == 'basic_256_multi':
        net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
    else:
        # Report the requested name (netD); `net` is still None at this point.
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)


def define_E(input_nc, output_nc, ndf, netE, norm='batch', nl='lrelu',
             init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl = 'lrelu'  # use leaky relu for E
    nl_layer = get_non_linearity(layer_type=nl)
    if netE == 'resnet_128':
        net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
                       nl_layer=nl_layer, vaeLike=vaeLike)
    elif netE == 'resnet_256':
        net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
                       nl_layer=nl_layer, vaeLike=vaeLike)
    elif netE == 'conv_128':
        net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
                        nl_layer=nl_layer, vaeLike=vaeLike)
    elif netE == 'conv_256':
        net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
                        nl_layer=nl_layer, vaeLike=vaeLike)
    else:
        # Report the requested name (netE); `net` is still None at this point.
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)

    return init_net(net, init_type, init_gain, gpu_ids, False)


class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, norm_layer=None, use_dropout=False, n_blocks=6, padding_type='replicate'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        model = [nn.ReplicationPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias)]
        if norm_layer is not None:
            model += [norm_layer(ngf)]
        model += [nn.ReLU(True)]

        # n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.ReplicationPad2d(1), nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                                        stride=2, padding=0, bias=use_bias)]
            # model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
            #                     stride=2, padding=1, bias=use_bias)]
            if norm_layer is not None:
                model += [norm_layer(ngf * mult * 2)]
            model += [nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            # model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
            #                              kernel_size=3, stride=2,
            #                              padding=1, output_padding=1,
            #                              bias=use_bias)]
            # if norm_layer is not None:
            #     model += [norm_layer(ngf * mult / 2)]
            # model += [nn.ReLU(True)]
            model += upsampleLayer(ngf * mult, int(ngf * mult / 2), upsample='bilinear', padding_type=padding_type)
            if norm_layer is not None:
                model += [norm_layer(int(ngf * mult / 2))]
            model += [nn.ReLU(True)]
            model += [nn.ReplicationPad2d(1),
                      nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2), kernel_size=3, padding=0)]
            if norm_layer is not None:
                # Cast to int: `ngf * mult / 2` is a float, and norm layers expect an integer channel count.
                model += [norm_layer(int(ngf * mult / 2))]
            model += [nn.ReLU(True)]
        model += [nn.ReplicationPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        #model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        if norm_layer is not None:
            conv_block += [norm_layer(dim)]
        conv_block += [nn.ReLU(True)]
        # if use_dropout:
        #     conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
        if norm_layer is not None:
            conv_block += [norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class D_NLayersMulti(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3,
                 norm_layer=nn.BatchNorm2d, num_D=1, nl_layer=None):
        super(D_NLayersMulti, self).__init__()
        # st()
        self.num_D = num_D
        self.nl_layer = nl_layer
        if num_D == 1:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.model = nn.Sequential(*layers)
        else:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.add_module("model_0", nn.Sequential(*layers))
            self.down = nn.functional.interpolate
            for i in range(1, num_D):
                ndf_i = int(round(ndf / (2**i)))
                layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
                self.add_module("model_%d" % i, nn.Sequential(*layers))

    def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        kw = 3
        padw = 1
        sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
                                            stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                                                 kernel_size=kw, stride=2, padding=padw))]
            if norm_layer:
                sequence += [norm_layer(ndf * nf_mult)]

            sequence += [self.nl_layer()]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                                             kernel_size=kw, stride=1, padding=padw))]
        if norm_layer:
            sequence += [norm_layer(ndf * nf_mult)]
        sequence += [self.nl_layer()]

        sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1,
                                             kernel_size=kw, stride=1, padding=padw))]

        return sequence

    def forward(self, input):
        if self.num_D == 1:
            return self.model(input)
        result = []
        down = input
        for i in range(self.num_D):
            model = getattr(self, "model_%d" % i)
            result.append(model(down))
            if i != self.num_D - 1:
                down = self.down(down, scale_factor=0.5, mode='bilinear')
        return result
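
When `num_D > 1`, `D_NLayersMulti` builds one sub-discriminator per scale: each extra scale halves the base filter count and sees an input downsampled by 0.5. The sketch below tabulates that plan in plain Python; `multiscale_plan` and its default arguments are hypothetical names and values used only for illustration.

```python
def multiscale_plan(ndf=64, num_D=3, height=256, width=256):
    """Return (scale index, base filters, input size) for each sub-discriminator.

    Follows ndf_i = int(round(ndf / 2**i)) and one 0.5x downsample per extra scale.
    """
    plan = []
    h, w = height, width
    for i in range(num_D):
        ndf_i = ndf if i == 0 else int(round(ndf / (2 ** i)))
        plan.append((i, ndf_i, (h, w)))
        # interpolate(scale_factor=0.5) floors the scaled spatial size.
        h, w = int(h * 0.5), int(w * 0.5)
    return plan


plan = multiscale_plan()
```

Because `forward` returns a list of per-scale prediction maps in this mode, any loss that consumes the output has to handle a list as well as a single tensor.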

class D_NLayers(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, nl_layer=None):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the first conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer     -- normalization layer
            nl_layer       -- accepted because define_D passes it; unused here (LeakyReLU(0.2) is hard-coded)
        """
        super(D_NLayers, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 3
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
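
The channel widths in `D_NLayers` follow a doubling rule capped at `8 * ndf`. The sketch below reproduces only that bookkeeping in plain Python, so the layer plan can be inspected without building the network; `patchgan_channels` is a hypothetical helper name, not part of this file.

```python
def patchgan_channels(input_nc, ndf=64, n_layers=3):
    """Return (in_channels, out_channels, stride) for each conv in the discriminator."""
    convs = [(input_nc, ndf, 2)]  # first conv, no normalization
    nf_mult = 1
    for n in range(1, n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
        convs.append((ndf * nf_mult_prev, ndf * nf_mult, 2))
    # One more widening conv at stride 1, then the 1-channel prediction map.
    nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
    convs.append((ndf * nf_mult_prev, ndf * nf_mult, 1))
    convs.append((ndf * nf_mult, 1, 1))
    return convs


basic_256 = patchgan_channels(3, ndf=64, n_layers=3)
basic_128 = patchgan_channels(3, ndf=64, n_layers=2)
```

The two calls correspond to the `n_layers` values that `define_D` uses for 'basic_256' and 'basic_128'; the 3-channel input is an assumption for the example.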


class G_Unet_add_input(nn.Module):
    def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
                 norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
                 upsample='basic', device=0):
        super(G_Unet_add_input, self).__init__()
        self.nz = nz
        max_nchn = 8
        noise = []
        for i in range(num_downs+1):
            if use_noise:
                noise.append(True)
            else:
                noise.append(False)

        # construct unet structure
        #print(num_downs)
        unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=noise[num_downs-1],
                                 innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        for i in range(num_downs - 5):
            unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise[num_downs-i-3],
                                     norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        unet_block = UnetBlock_A(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_A(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_A(ngf, ngf, ngf * 2, unet_block, noise[0],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock_A(input_nc + nz, output_nc, ngf, unet_block, None,
                                 outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)

        self.model = unet_block

    def forward(self, x, z=None):
        if self.nz > 0:
            z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
                z.size(0), z.size(1), x.size(2), x.size(3))
            x_with_z = torch.cat([x, z_img], 1)
        else:
            x_with_z = x  # no z


        return torch.tanh(self.model(x_with_z))
        # return self.model(x_with_z)
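
The `num_downs` argument fixes how many nested blocks the U-Net has (one innermost, `num_downs - 5` middle, three narrowing, one outermost), and each block halves the spatial size with a 3x3, stride-2 convolution after one pixel of padding. The sketch below applies the standard convolution output-size formula to show why `define_G` pairs 7 levels with 128-pixel inputs and 8 levels with 256-pixel inputs; `conv_out` and `bottleneck_size` are hypothetical helper names.

```python
def conv_out(size, kernel=3, stride=2, pad=1):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * pad - kernel) // stride + 1


def bottleneck_size(size, num_downs):
    """Spatial size at the innermost block after num_downs stride-2 convolutions."""
    for _ in range(num_downs):
        size = conv_out(size)
    return size


sizes = {
    "unet_128": bottleneck_size(128, 7),
    "unet_256": bottleneck_size(256, 8),
    "unet_128_on_256_input": bottleneck_size(256, 7),
}
```

Under these assumptions both named configurations reach a 1x1 bottleneck on their nominal input size, while feeding a 256-pixel image to the 7-level network leaves a 2x2 bottleneck.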

class G_Unet_add_input_G(nn.Module):
    def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
                 norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
                 upsample='basic', device=0):
        super(G_Unet_add_input_G, self).__init__()
        self.nz = nz
        max_nchn = 8
        noise = []
        for i in range(num_downs+1):
            if use_noise:
                noise.append(True)
            else:
                noise.append(False)
        # construct unet structure
        #print(num_downs)
        unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
                                 innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        for i in range(num_downs - 5):
            unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
                                     norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        unet_block = UnetBlock_G(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
        unet_block = UnetBlock_G(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
        unet_block = UnetBlock_G(ngf, ngf, ngf * 2, unet_block, noise[0],
                                 norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
        unet_block = UnetBlock_G(input_nc + nz, output_nc, ngf, unet_block, None,
                                 outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')

        self.model = unet_block

    def forward(self, x, z=None):
        if self.nz > 0:
            z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
                z.size(0), z.size(1), x.size(2), x.size(3))
            x_with_z = torch.cat([x, z_img], 1)
        else:
            x_with_z = x  # no z

        # return F.tanh(self.model(x_with_z))
        return self.model(x_with_z)

class G_Unet_add_input_C(nn.Module):
    def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
                 norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
                 upsample='basic', device=0):
        super(G_Unet_add_input_C, self).__init__()
        self.nz = nz
        max_nchn = 8
        # construct unet structure
        #print(num_downs)
        unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
                               innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        for i in range(num_downs - 5):
            unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
                                   norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
        unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise=False,
                               norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, noise=False,
                               norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, noise=False,
                               norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
        unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, noise=False,
                               outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)

        self.model = unet_block

    def forward(self, x, z=None):
        if self.nz > 0:
            z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
                z.size(0), z.size(1), x.size(2), x.size(3))
            x_with_z = torch.cat([x, z_img], 1)
        else:
            x_with_z = x  # no z

        # return torch.tanh(self.model(x_with_z))
        return self.model(x_with_z)

def upsampleLayer(inplanes, outplanes, kw=1, upsample='basic', padding_type='replicate'):
    # padding_type = 'zero'
    if upsample == 'basic':
        upconv = [nn.ConvTranspose2d(inplanes, outplanes, kernel_size=4, stride=2, padding=1)]  #, padding_mode='replicate'
    elif upsample == 'bilinear' or upsample == 'nearest' or upsample == 'linear':
        # align_corners is only valid for interpolating modes; 'nearest' rejects it at forward time.
        upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=(True if upsample != 'nearest' else None)),
                  #nn.ReplicationPad2d(1),
                  nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)]
        # p = kw//2
        # upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
        #           nn.Conv2d(inplanes, outplanes, kernel_size=kw, stride=1, padding=p, padding_mode='replicate')]
    else:
        raise NotImplementedError(
            'upsample layer [%s] not implemented' % upsample)
    return upconv
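
The 'basic' branch of `upsampleLayer` relies on a transposed convolution with `kernel_size=4, stride=2, padding=1` doubling the spatial size exactly, which keeps it interchangeable with the `scale_factor=2` upsampling branch. The sketch below checks this with the standard transposed-convolution size formula (dilation 1); `conv_transpose_out` is a hypothetical helper name.

```python
def conv_transpose_out(size, kernel=4, stride=2, pad=1, output_padding=0):
    """Output length of a transposed convolution along one dimension (dilation 1)."""
    return (size - 1) * stride - 2 * pad + kernel + output_padding


# Settings used by the 'basic' branch: every input size is doubled.
doubled = [conv_transpose_out(s) for s in (1, 16, 64)]

# The commented-out alternative in ResnetGenerator (kernel 3, output_padding 1) also doubles.
alt = conv_transpose_out(16, kernel=3, stride=2, pad=1, output_padding=1)
```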

class UnetBlock_G(nn.Module):
    def __init__(self, input_nc, outer_nc, inner_nc,
                 submodule=None, noise=None, outermost=False, innermost=False,
                 norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
        super(UnetBlock_G, self).__init__()
        self.outermost = outermost
        p = 0
        downconv = []
        if padding_type == 'reflect':
            downconv += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            downconv += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        downconv += [nn.Conv2d(input_nc, inner_nc,
                               kernel_size=3, stride=2, padding=p)]
        # downsample is different from upsample
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc) if norm_layer is not None else None
        uprelu = nl_layer()
        uprelu2 = nl_layer()
        uppad = nn.ReplicationPad2d(1)
        upnorm = norm_layer(outer_nc) if norm_layer is not None else None
        upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
        self.noiseblock = ApplyNoise(outer_nc)
        self.noise = noise

        if outermost:
            upconv = upsampleLayer(inner_nc * 2, inner_nc, upsample=upsample, padding_type=padding_type)
            uppad = nn.ReplicationPad2d(3)
            upconv2 = nn.Conv2d(inner_nc, outer_nc, kernel_size=7, padding=0)
            down = downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [norm_layer(inner_nc)]
            # upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
            # upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=0)
            # down = downconv
            # up = [uprelu] + upconv
            # if upnorm is not None:
            #     up += [norm_layer(outer_nc)]
            up += [uprelu2, uppad, upconv2]  #+ [nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]
            model = down + up
        else:
            upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            if downnorm is not None:
                down += [downnorm]
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            x2 = self.model(x)
            if self.noise:
                x2 = self.noiseblock(x2, self.noise)
            return torch.cat([x2, x], 1)


class UnetBlock(nn.Module):
    def __init__(self, input_nc, outer_nc, inner_nc,
                 submodule=None, noise=None, outermost=False, innermost=False,
                 norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
        super(UnetBlock, self).__init__()
        self.outermost = outermost
        p = 0
        downconv = []
        if padding_type == 'reflect':
            downconv += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            downconv += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        downconv += [nn.Conv2d(input_nc, inner_nc,
                               kernel_size=3, stride=2, padding=p)]
        # downsample is different from upsample
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc) if norm_layer is not None else None
        uprelu = nl_layer()
        uprelu2 = nl_layer()
        uppad = nn.ReplicationPad2d(1)
        upnorm = norm_layer(outer_nc) if norm_layer is not None else None
        upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
        self.noiseblock = ApplyNoise(outer_nc)
        self.noise = noise

        if outermost:
            upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]  #+ [nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]
            model = down + up
        else:
            upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
            down = [downrelu] + downconv
            if downnorm is not None:
                down += [downnorm]
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]
            if upnorm2 is not None:
                up += [upnorm2]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            x2 = self.model(x)
            if self.noise:
                x2 = self.noiseblock(x2, self.noise)
            return torch.cat([x2, x], 1)

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetBlock_A(nn.Module):
    def __init__(self, input_nc, outer_nc, inner_nc,
                 submodule=None, noise=None, outermost=False, innermost=False,
                 norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
        super(UnetBlock_A, self).__init__()
        self.outermost = outermost
        p = 0
        downconv = []
        if padding_type == 'reflect':
            downconv += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            downconv += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
                                             kernel_size=3, stride=2, padding=p))]
        # downsample is different from upsample
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc) if norm_layer is not None else None
        uprelu = nl_layer()
        uprelu2 = nl_layer()
        uppad = nn.ReplicationPad2d(1)
        upnorm = norm_layer(outer_nc) if norm_layer is not None else None
        upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
        self.noiseblock = ApplyNoise(outer_nc)
        self.noise = noise

        if outermost:
            upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
            down = downconv
            up = [uprelu] + upconv
            if upnorm is not None:
                up += [upnorm]
            up += [uprelu2, uppad, upconv2]  #+ [nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
            upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
|
811 |
+
down = [downrelu] + downconv
|
812 |
+
up = [uprelu] + upconv
|
813 |
+
if upnorm is not None:
|
814 |
+
up += [upnorm]
|
815 |
+
up += [uprelu2, uppad, upconv2]
|
816 |
+
if upnorm2 is not None:
|
817 |
+
up += [upnorm2]
|
818 |
+
model = down + up
|
819 |
+
else:
|
820 |
+
upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
|
821 |
+
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
|
822 |
+
down = [downrelu] + downconv
|
823 |
+
if downnorm is not None:
|
824 |
+
down += [downnorm]
|
825 |
+
up = [uprelu] + upconv
|
826 |
+
if upnorm is not None:
|
827 |
+
up += [upnorm]
|
828 |
+
up += [uprelu2, uppad, upconv2]
|
829 |
+
if upnorm2 is not None:
|
830 |
+
up += [upnorm2]
|
831 |
+
|
832 |
+
if use_dropout:
|
833 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
834 |
+
else:
|
835 |
+
model = down + [submodule] + up
|
836 |
+
|
837 |
+
self.model = nn.Sequential(*model)
|
838 |
+
|
839 |
+
def forward(self, x):
|
840 |
+
if self.outermost:
|
841 |
+
return self.model(x)
|
842 |
+
else:
|
843 |
+
x2 = self.model(x)
|
844 |
+
if self.noise:
|
845 |
+
x2 = self.noiseblock(x2, self.noise)
|
846 |
+
if x2.shape[-1]==x.shape[-1]:
|
847 |
+
return x2 + x
|
848 |
+
else:
|
849 |
+
x2 = F.interpolate(x2, x.shape[2:])
|
850 |
+
return x2 + x
|
851 |
+
|
852 |
+
|
853 |
+
class E_ResNet(nn.Module):
|
854 |
+
def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
|
855 |
+
norm_layer=None, nl_layer=None, vaeLike=False):
|
856 |
+
super(E_ResNet, self).__init__()
|
857 |
+
self.vaeLike = vaeLike
|
858 |
+
max_ndf = 4
|
859 |
+
conv_layers = [
|
860 |
+
nn.Conv2d(input_nc, ndf, kernel_size=3, stride=2, padding=1, bias=True)]
|
861 |
+
for n in range(1, n_blocks):
|
862 |
+
input_ndf = ndf * min(max_ndf, n)
|
863 |
+
output_ndf = ndf * min(max_ndf, n + 1)
|
864 |
+
conv_layers += [BasicBlock(input_ndf,
|
865 |
+
output_ndf, norm_layer, nl_layer)]
|
866 |
+
conv_layers += [nl_layer(), nn.AdaptiveAvgPool2d(4)]
|
867 |
+
if vaeLike:
|
868 |
+
self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
|
869 |
+
self.fcVar = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
|
870 |
+
else:
|
871 |
+
self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
|
872 |
+
self.conv = nn.Sequential(*conv_layers)
|
873 |
+
|
874 |
+
    def forward(self, x):
        x_conv = self.conv(x)
        conv_flat = x_conv.view(x.size(0), -1)
        output = self.fc(conv_flat)
        if self.vaeLike:
            outputVar = self.fcVar(conv_flat)
            return output, outputVar
        return output
|
884 |
+
|
885 |
+
|
886 |
+
# Defines the Unet generator.
|
887 |
+
# |num_downs|: number of downsamplings in UNet. For example,
|
888 |
+
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
|
889 |
+
# at the bottleneck
|
890 |
+
class G_Unet_add_all(nn.Module):
|
891 |
+
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
|
892 |
+
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, upsample='basic'):
|
893 |
+
super(G_Unet_add_all, self).__init__()
|
894 |
+
self.nz = nz
|
895 |
+
self.mapping = G_mapping(self.nz, self.nz, 512, normalize_latents=False, lrmul=1)
|
896 |
+
self.truncation_psi = 0
|
897 |
+
self.truncation_cutoff = 0
|
898 |
+
|
899 |
+
# The "- 2" accounts for starting from a 4x4 feature map.
# With the resolution of 512 used here, this gives num_layers = 16.
|
901 |
+
num_layers = int(np.log2(512)) * 2 - 2
|
902 |
+
# Noise inputs.
|
903 |
+
self.noise_inputs = []
|
904 |
+
for layer_idx in range(num_layers):
|
905 |
+
res = layer_idx // 2 + 2
|
906 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
907 |
+
self.noise_inputs.append(torch.randn(*shape).to("cuda"))
|
908 |
+
|
909 |
+
# construct unet structure
|
910 |
+
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
|
911 |
+
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
|
912 |
+
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
|
913 |
+
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
|
914 |
+
for i in range(num_downs - 6):
|
915 |
+
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
|
916 |
+
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
|
917 |
+
unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, submodule=unet_block,
|
918 |
+
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
|
919 |
+
unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, submodule=unet_block,
|
920 |
+
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
|
921 |
+
unet_block = UnetBlock_with_z(ngf, ngf, ngf * 2, nz, submodule=unet_block,
|
922 |
+
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
|
923 |
+
unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, submodule=unet_block,
|
924 |
+
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
|
925 |
+
self.model = unet_block
|
926 |
+
|
927 |
+
def forward(self, x, z):
|
928 |
+
|
929 |
+
dlatents1, num_layers = self.mapping(z)
|
930 |
+
dlatents1 = dlatents1.unsqueeze(1)
|
931 |
+
dlatents1 = dlatents1.expand(-1, int(num_layers), -1)
|
932 |
+
|
933 |
+
# Apply truncation trick.
|
934 |
+
if self.truncation_psi and self.truncation_cutoff:
|
935 |
+
coefs = np.ones([1, num_layers, 1], dtype=np.float32)
|
936 |
+
for i in range(num_layers):
|
937 |
+
if i < self.truncation_cutoff:
|
938 |
+
coefs[:, i, :] *= self.truncation_psi
|
939 |
+
"""Linear interpolation.
|
940 |
+
a + (b - a) * t (a = 0)
|
941 |
+
reduce to
|
942 |
+
b * t
|
943 |
+
"""
|
944 |
+
dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device)
|
945 |
+
|
946 |
+
return torch.tanh(self.model(x, dlatents1, self.noise_inputs))
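The layer count and per-layer noise shapes built in `G_Unet_add_all.__init__` can be checked without PyTorch. This is a minimal sketch, assuming the hard-coded resolution of 512 from the constructor; `noise_shapes` and `truncation_coefs` are hypothetical helper names used only for illustration and are not part of the repository.

```python
import math


def noise_shapes(resolution=512):
    """Return the [1, 1, H, W] noise shapes built by the constructor loop."""
    num_layers = int(math.log2(resolution)) * 2 - 2
    shapes = []
    for layer_idx in range(num_layers):
        res = layer_idx // 2 + 2  # two consecutive layers share one resolution
        shapes.append([1, 1, 2 ** res, 2 ** res])
    return shapes


def truncation_coefs(num_layers, psi, cutoff):
    """Per-layer scale factors: psi below the cutoff index, 1.0 otherwise."""
    return [psi if i < cutoff else 1.0 for i in range(num_layers)]


shapes = noise_shapes(512)
print(len(shapes))   # 16
print(shapes[0])     # [1, 1, 4, 4]
print(shapes[-1])    # [1, 1, 512, 512]
print(truncation_coefs(4, 0.7, 2))  # [0.7, 0.7, 1.0, 1.0]
```

Because `truncation_psi` and `truncation_cutoff` are both set to 0 in the constructor, the truncation branch in `forward` is skipped as the code is written.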
|
947 |
+
|
948 |
+
|
949 |
+
class ApplyNoise(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.weight = nn.Parameter(torch.randn(channels), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)

    def forward(self, x, noise):
        # Keep the learned scale and bias for the first half of the channels only;
        # the second half is zeroed, so those channels pass through unchanged.
        W, _ = torch.split(self.weight.view(1, -1, 1, 1), self.channels // 2, dim=1)
        B, _ = torch.split(self.bias.view(1, -1, 1, 1), self.channels // 2, dim=1)
        Z = torch.zeros_like(W)
        w = torch.cat([W, Z], dim=1).to(x.device)
        b = torch.cat([B, Z], dim=1).to(x.device)
        # Note: the `noise` argument is not used; fresh Gaussian noise is drawn per call.
        adds = w * torch.randn_like(x) + b
        return x + adds.type_as(x)
|
964 |
+
|
965 |
+
|
966 |
+
class FC(nn.Module):
|
967 |
+
def __init__(self,
|
968 |
+
in_channels,
|
969 |
+
out_channels,
|
970 |
+
gain=2**(0.5),
|
971 |
+
use_wscale=False,
|
972 |
+
lrmul=1.0,
|
973 |
+
bias=True):
|
974 |
+
"""
|
975 |
+
The complete conversion of Dense/FC/Linear Layer of original Tensorflow version.
|
976 |
+
"""
|
977 |
+
super(FC, self).__init__()
|
978 |
+
he_std = gain * in_channels ** (-0.5) # He init
|
979 |
+
if use_wscale:
|
980 |
+
init_std = 1.0 / lrmul
|
981 |
+
self.w_lrmul = he_std * lrmul
|
982 |
+
else:
|
983 |
+
init_std = he_std / lrmul
|
984 |
+
self.w_lrmul = lrmul
|
985 |
+
|
986 |
+
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)
|
987 |
+
if bias:
|
988 |
+
self.bias = torch.nn.Parameter(torch.zeros(out_channels))
|
989 |
+
self.b_lrmul = lrmul
|
990 |
+
else:
|
991 |
+
self.bias = None
|
992 |
+
|
993 |
+
def forward(self, x):
|
994 |
+
if self.bias is not None:
|
995 |
+
out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)
|
996 |
+
else:
|
997 |
+
out = F.linear(x, self.weight * self.w_lrmul)
|
998 |
+
out = F.leaky_relu(out, 0.2, inplace=True)
|
999 |
+
return out
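The two branches of the `FC` constructor trade off the initial weight scale against the runtime multiplier. The sketch below repeats only that arithmetic in plain Python; `fc_scales` is a hypothetical helper introduced for illustration, and the input width of 512 and `lrmul=0.01` are assumed example values.

```python
import math


def fc_scales(in_channels, gain=2 ** 0.5, use_wscale=False, lrmul=1.0):
    """Mirror the init_std / w_lrmul computation from FC.__init__."""
    he_std = gain * in_channels ** (-0.5)  # He init
    if use_wscale:
        init_std = 1.0 / lrmul
        w_lrmul = he_std * lrmul
    else:
        init_std = he_std / lrmul
        w_lrmul = lrmul
    return init_std, w_lrmul


for use_wscale in (True, False):
    init_std, w_lrmul = fc_scales(512, use_wscale=use_wscale, lrmul=0.01)
    # The weight seen by F.linear is weight * w_lrmul, so its std is the product.
    effective_std = init_std * w_lrmul
    print(use_wscale, init_std, w_lrmul, effective_std)
```

In both branches the product of the two factors equals `he_std`, so the effective weight scale at initialization is the same; only the scale of the stored parameter differs.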
|
1000 |
+
|
1001 |
+
|
1002 |
+
class ApplyStyle(nn.Module):
|
1003 |
+
"""
|
1004 |
+
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
1005 |
+
"""
|
1006 |
+
def __init__(self, latent_size, channels, use_wscale, nl_layer):
|
1007 |
+
super(ApplyStyle, self).__init__()
|
1008 |
+
modules = [nn.Linear(latent_size, channels*2)]
|
1009 |
+
if nl_layer:
|
1010 |
+
modules += [nl_layer()]
|
1011 |
+
self.linear = nn.Sequential(*modules)
|
1012 |
+
|
1013 |
+
def forward(self, x, latent):
|
1014 |
+
style = self.linear(latent) # style => [batch_size, n_channels*2]
|
1015 |
+
shape = [-1, 2, x.size(1), 1, 1]
|
1016 |
+
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
1017 |
+
x = x * (style[:, 0] + 1.) + style[:, 1]
|
1018 |
+
return x
|
1019 |
+
|
1020 |
+
class PixelNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
        @notice: avoid in-place ops.
        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(PixelNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        tmp = torch.mul(x, x)  # or x ** 2
        tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)

        return x * tmp1


class InstanceNorm(nn.Module):
    def __init__(self, epsilon=1e-8):
        """
        @notice: avoid in-place ops.
        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super(InstanceNorm, self).__init__()
        self.epsilon = epsilon

    def forward(self, x):
        x = x - torch.mean(x, (2, 3), True)
        tmp = torch.mul(x, x)  # or x ** 2
        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
        return x * tmp
|
1050 |
+
|
1051 |
+
|
1052 |
+
class LayerEpilogue(nn.Module):
|
1053 |
+
def __init__(self, channels, dlatent_size, use_wscale, use_noise,
|
1054 |
+
use_pixel_norm, use_instance_norm, use_styles, nl_layer=None):
|
1055 |
+
super(LayerEpilogue, self).__init__()
|
1056 |
+
self.use_noise = use_noise
|
1057 |
+
if use_noise:
|
1058 |
+
self.noise = ApplyNoise(channels)
|
1059 |
+
self.act = nn.LeakyReLU(negative_slope=0.2)
|
1060 |
+
|
1061 |
+
if use_pixel_norm:
|
1062 |
+
self.pixel_norm = PixelNorm()
|
1063 |
+
else:
|
1064 |
+
self.pixel_norm = None
|
1065 |
+
|
1066 |
+
if use_instance_norm:
|
1067 |
+
self.instance_norm = InstanceNorm()
|
1068 |
+
else:
|
1069 |
+
self.instance_norm = None
|
1070 |
+
|
1071 |
+
if use_styles:
|
1072 |
+
self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale, nl_layer=nl_layer)
|
1073 |
+
else:
|
1074 |
+
self.style_mod = None
|
1075 |
+
|
1076 |
+
def forward(self, x, noise, dlatents_in_slice=None):
|
1077 |
+
# if noise is not None:
|
1078 |
+
if self.use_noise:
|
1079 |
+
x = self.noise(x, noise)
|
1080 |
+
x = self.act(x)
|
1081 |
+
if self.pixel_norm is not None:
|
1082 |
+
x = self.pixel_norm(x)
|
1083 |
+
if self.instance_norm is not None:
|
1084 |
+
x = self.instance_norm(x)
|
1085 |
+
if self.style_mod is not None:
|
1086 |
+
x = self.style_mod(x, dlatents_in_slice)
|
1087 |
+
|
1088 |
+
return x
|
1089 |
+
|
1090 |
+
class G_mapping(nn.Module):
|
1091 |
+
def __init__(self,
|
1092 |
+
mapping_fmaps=512,
|
1093 |
+
dlatent_size=512,
|
1094 |
+
resolution=512,
|
1095 |
+
normalize_latents=True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
|
1096 |
+
use_wscale=True, # Enable equalized learning rate?
|
1097 |
+
lrmul=0.01, # Learning rate multiplier for the mapping layers.
|
1098 |
+
gain=2**(0.5), # original gain in tensorflow.
|
1099 |
+
nl_layer=None
|
1100 |
+
):
|
1101 |
+
super(G_mapping, self).__init__()
|
1102 |
+
self.mapping_fmaps = mapping_fmaps
|
1103 |
+
func = [
|
1104 |
+
nn.Linear(self.mapping_fmaps, dlatent_size)
|
1105 |
+
]
|
1106 |
+
if nl_layer:
|
1107 |
+
func += [nl_layer()]
|
1108 |
+
|
1109 |
+
for j in range(0,4):
|
1110 |
+
func += [
|
1111 |
+
nn.Linear(dlatent_size, dlatent_size)
|
1112 |
+
]
|
1113 |
+
if nl_layer:
|
1114 |
+
func += [nl_layer()]
|
1115 |
+
|
1116 |
+
self.func = nn.Sequential(*func)
|
1117 |
+
#FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
|
1118 |
+
#FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
|
1119 |
+
|
1120 |
+
self.normalize_latents = normalize_latents
|
1121 |
+
self.resolution_log2 = int(np.log2(resolution))
|
1122 |
+
self.num_layers = self.resolution_log2 * 2 - 2
|
1123 |
+
self.pixel_norm = PixelNorm()
|
1124 |
+
# The "- 2" accounts for starting from a 4x4 feature map.
# With the default resolution of 512, this gives num_layers = 16.
|
1126 |
+
|
1127 |
+
def forward(self, x):
|
1128 |
+
if self.normalize_latents:
|
1129 |
+
x = self.pixel_norm(x)
|
1130 |
+
out = self.func(x)
|
1131 |
+
return out, self.num_layers
|
1132 |
+
|
1133 |
+
class UnetBlock_with_z(nn.Module):
|
1134 |
+
def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
|
1135 |
+
submodule=None, outermost=False, innermost=False,
|
1136 |
+
norm_layer=None, nl_layer=None, use_dropout=False,
|
1137 |
+
upsample='basic', padding_type='replicate'):
|
1138 |
+
super(UnetBlock_with_z, self).__init__()
|
1139 |
+
p = 0
|
1140 |
+
downconv = []
|
1141 |
+
if padding_type == 'reflect':
|
1142 |
+
downconv += [nn.ReflectionPad2d(1)]
|
1143 |
+
elif padding_type == 'replicate':
|
1144 |
+
downconv += [nn.ReplicationPad2d(1)]
|
1145 |
+
elif padding_type == 'zero':
|
1146 |
+
p = 1
|
1147 |
+
else:
|
1148 |
+
raise NotImplementedError(
|
1149 |
+
'padding [%s] is not implemented' % padding_type)
|
1150 |
+
|
1151 |
+
self.outermost = outermost
|
1152 |
+
self.innermost = innermost
|
1153 |
+
self.nz = nz
|
1154 |
+
|
1155 |
+
# input_nc = input_nc + nz
|
1156 |
+
downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
|
1157 |
+
kernel_size=3, stride=2, padding=p))]
|
1158 |
+
# downsample is different from upsample
|
1159 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
1160 |
+
downnorm = norm_layer(inner_nc) if norm_layer is not None else None
|
1161 |
+
uprelu = nl_layer()
|
1162 |
+
uprelu2 = nl_layer()
|
1163 |
+
uppad = nn.ReplicationPad2d(1)
|
1164 |
+
upnorm = norm_layer(outer_nc) if norm_layer is not None else None
|
1165 |
+
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
|
1166 |
+
|
1167 |
+
use_styles=False
|
1168 |
+
uprelu = nl_layer()
|
1169 |
+
if self.nz >0:
|
1170 |
+
use_styles=True
|
1171 |
+
|
1172 |
+
if outermost:
|
1173 |
+
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
|
1174 |
+
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
|
1175 |
+
upconv = upsampleLayer(
|
1176 |
+
inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
|
1177 |
+
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
|
1178 |
+
down = downconv
|
1179 |
+
up = [uprelu] + upconv
|
1180 |
+
if upnorm is not None:
|
1181 |
+
up += [upnorm]
|
1182 |
+
up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
|
1183 |
+
elif innermost:
|
1184 |
+
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=True,
|
1185 |
+
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
|
1186 |
+
upconv = upsampleLayer(
|
1187 |
+
inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
|
1188 |
+
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
|
1189 |
+
down = [downrelu] + downconv
|
1190 |
+
up = [uprelu] + upconv
|
1191 |
+
if norm_layer is not None:
|
1192 |
+
up += [norm_layer(outer_nc)]
|
1193 |
+
up += [uprelu2, uppad, upconv2]
|
1194 |
+
if upnorm2 is not None:
|
1195 |
+
up += [upnorm2]
|
1196 |
+
else:
|
1197 |
+
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
|
1198 |
+
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
|
1199 |
+
upconv = upsampleLayer(
|
1200 |
+
inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
|
1201 |
+
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
|
1202 |
+
down = [downrelu] + downconv
|
1203 |
+
if norm_layer is not None:
|
1204 |
+
down += [norm_layer(inner_nc)]
|
1205 |
+
up = [uprelu] + upconv
|
1206 |
+
|
1207 |
+
if norm_layer is not None:
|
1208 |
+
up += [norm_layer(outer_nc)]
|
1209 |
+
up += [uprelu2, uppad, upconv2]
|
1210 |
+
if upnorm2 is not None:
|
1211 |
+
up += [upnorm2]
|
1212 |
+
|
1213 |
+
if use_dropout:
|
1214 |
+
up += [nn.Dropout(0.5)]
|
1215 |
+
self.down = nn.Sequential(*down)
|
1216 |
+
self.submodule = submodule
|
1217 |
+
self.up = nn.Sequential(*up)
|
1218 |
+
|
1219 |
+
|
1220 |
+
def forward(self, x, z, noise):
|
1221 |
+
if self.outermost:
|
1222 |
+
x1 = self.down(x)
|
1223 |
+
x2 = self.submodule(x1, z[:,2:], noise[2:])
|
1224 |
+
return self.up(x2)
|
1225 |
+
|
1226 |
+
elif self.innermost:
|
1227 |
+
x1 = self.down(x)
|
1228 |
+
x_and_z = self.adaIn(x1, noise[0], z[:,0])
|
1229 |
+
x2 = self.up(x_and_z)
|
1230 |
+
x2 = F.interpolate(x2, x.shape[2:])
|
1231 |
+
return x2 + x
|
1232 |
+
|
1233 |
+
else:
|
1234 |
+
x1 = self.down(x)
|
1235 |
+
x2 = self.submodule(x1, z[:,2:], noise[2:])
|
1236 |
+
x_and_z = self.adaIn(x2, noise[0], z[:,0])
|
1237 |
+
return self.up(x_and_z) + x
|
1238 |
+
|
1239 |
+
|
1240 |
+
class E_NLayers(nn.Module):
|
1241 |
+
def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=4,
|
1242 |
+
norm_layer=None, nl_layer=None, vaeLike=False):
|
1243 |
+
super(E_NLayers, self).__init__()
|
1244 |
+
self.vaeLike = vaeLike
|
1245 |
+
|
1246 |
+
kw, padw = 3, 1
|
1247 |
+
sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
|
1248 |
+
stride=2, padding=padw, padding_mode='replicate')), nl_layer()]
|
1249 |
+
|
1250 |
+
nf_mult = 1
|
1251 |
+
nf_mult_prev = 1
|
1252 |
+
for n in range(1, n_layers):
|
1253 |
+
nf_mult_prev = nf_mult
|
1254 |
+
nf_mult = min(2**n, 8)
|
1255 |
+
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
1256 |
+
kernel_size=kw, stride=2, padding=padw, padding_mode='replicate'))]
|
1257 |
+
if norm_layer is not None:
|
1258 |
+
sequence += [norm_layer(ndf * nf_mult)]
|
1259 |
+
sequence += [nl_layer()]
|
1260 |
+
sequence += [nn.AdaptiveAvgPool2d(4)]
|
1261 |
+
self.conv = nn.Sequential(*sequence)
|
1262 |
+
self.fc = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
|
1263 |
+
if vaeLike:
|
1264 |
+
self.fcVar = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
|
1265 |
+
|
1266 |
+
def forward(self, x):
|
1267 |
+
x_conv = self.conv(x)
|
1268 |
+
conv_flat = x_conv.view(x.size(0), -1)
|
1269 |
+
output = self.fc(conv_flat)
|
1270 |
+
if self.vaeLike:
|
1271 |
+
outputVar = self.fcVar(conv_flat)
|
1272 |
+
return output, outputVar
|
1273 |
+
return output
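The channel widths in `E_NLayers` double at each strided convolution and are capped at eight times `ndf`. The sketch below reproduces that bookkeeping in plain Python; `encoder_widths` is a hypothetical helper for illustration, and the default values of `ndf=64` and `n_layers=4` are taken from the constructor signature.

```python
def encoder_widths(ndf=64, n_layers=4):
    """Return the output width of every conv layer and the fc input size."""
    widths = [ndf]  # first convolution: input_nc -> ndf
    nf_mult = 1
    for n in range(1, n_layers):
        nf_mult = min(2 ** n, 8)  # double each layer, capped at 8x
        widths.append(ndf * nf_mult)
    fc_in = ndf * nf_mult * 16  # AdaptiveAvgPool2d(4) leaves a 4x4 map per channel
    return widths, fc_in


print(encoder_widths(64, 4))  # ([64, 128, 256, 512], 8192)
print(encoder_widths(64, 6))  # ([64, 128, 256, 512, 512, 512], 8192)
```

Because the pooling layer always outputs a 4x4 map, the size of the linear layer depends only on the final channel width and not on the input image size.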
|
1274 |
+
|
1275 |
+
class BasicBlock(nn.Module):
|
1276 |
+
def __init__(self, inplanes, outplanes):
|
1277 |
+
super(BasicBlock, self).__init__()
|
1278 |
+
layers = []
|
1279 |
+
norm_layer=get_norm_layer(norm_type='layer') #functools.partial(LayerNorm)
|
1280 |
+
# norm_layer = None
|
1281 |
+
nl_layer=nn.ReLU()
|
1282 |
+
if norm_layer is not None:
|
1283 |
+
layers += [norm_layer(inplanes)]
|
1284 |
+
layers += [nl_layer]
|
1285 |
+
layers += [nn.ReplicationPad2d(1),
|
1286 |
+
nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1,
|
1287 |
+
padding=0, bias=True)]
|
1288 |
+
self.conv = nn.Sequential(*layers)
|
1289 |
+
|
1290 |
+
def forward(self, x):
|
1291 |
+
return self.conv(x)
|
1292 |
+
|
1293 |
+
|
1294 |
+
def define_SVAE(inc=96, outc=3, outplanes=64, blocks=1, netVAE='SVAE', model_name='', load_ext=True, save_dir='',
                init_type="normal", init_gain=0.02, gpu_ids=[]):
    if netVAE == 'SVAE':
        net = ScreenVAE(inc=inc, outc=outc, outplanes=outplanes, blocks=blocks, save_dir=save_dir,
                        init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
    else:
        # Report the requested name; `net` is not defined on this branch.
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netVAE)
    init_net(net, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
    net.load_networks('latest')
    return net
|
1304 |
+
|
1305 |
+
|
1306 |
+
class ScreenVAE(nn.Module):
|
1307 |
+
def __init__(self,inc=1,outc=4, outplanes=64, downs=5, blocks=2,load_ext=True, save_dir='',init_type="normal", init_gain=0.02, gpu_ids=[]):
|
1308 |
+
super(ScreenVAE, self).__init__()
|
1309 |
+
self.inc = inc
|
1310 |
+
self.outc = outc
|
1311 |
+
self.save_dir = save_dir
|
1312 |
+
norm_layer=functools.partial(LayerNormWarpper)
|
1313 |
+
nl_layer=nn.LeakyReLU
|
1314 |
+
|
1315 |
+
self.model_names=['enc','dec']
|
1316 |
+
self.enc=define_C(inc+1, outc*2, 0, 24, netC='resnet_6blocks',
|
1317 |
+
norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
|
1318 |
+
gpu_ids=gpu_ids, upsample='bilinear')
|
1319 |
+
self.dec=define_G(outc, inc, 0, 48, netG='unet_128_G',
|
1320 |
+
norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
|
1321 |
+
gpu_ids=gpu_ids, where_add='input', upsample='bilinear', use_noise=True)
|
1322 |
+
|
1323 |
+
for param in self.parameters():
|
1324 |
+
param.requires_grad = False
|
1325 |
+
|
1326 |
+
def load_networks(self, epoch):
|
1327 |
+
"""Load all the networks from the disk.
|
1328 |
+
|
1329 |
+
Parameters:
|
1330 |
+
epoch (int or str) -- epoch label, e.g. 'latest'; used in the file name '%s_net_%s.pth' % (epoch, name)
|
1331 |
+
"""
|
1332 |
+
for name in self.model_names:
|
1333 |
+
if isinstance(name, str):
|
1334 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
1335 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
1336 |
+
net = getattr(self, name)
|
1337 |
+
if isinstance(net, torch.nn.DataParallel):
|
1338 |
+
net = net.module
|
1339 |
+
print('loading the model from %s' % load_path)
|
1340 |
+
state_dict = torch.load(
|
1341 |
+
load_path, map_location=lambda storage, loc: storage.cuda())
|
1342 |
+
if hasattr(state_dict, '_metadata'):
|
1343 |
+
del state_dict._metadata
|
1344 |
+
|
1345 |
+
net.load_state_dict(state_dict)
|
1346 |
+
del state_dict
|
1347 |
+
|
1348 |
+
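    def npad(self, im, pad=128):
        # Pad height and width on the bottom/right up to the next multiple of `pad`.
        # A full extra `pad` is added even when the size is already a multiple of `pad`.
        h, w = im.shape[-2:]
        hp = h // pad * pad + pad
        wp = w // pad * pad + pad
        return F.pad(im, (0, wp - w, 0, hp - h), mode='replicate')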
|
1353 |
+
|
1354 |
+
def forward(self, x, line=None, img_input=True, output_screen_only=True):
|
1355 |
+
if img_input:
|
1356 |
+
if line is None:
|
1357 |
+
line = torch.ones_like(x)
|
1358 |
+
else:
|
1359 |
+
line = torch.sign(line)
|
1360 |
+
x = torch.clamp(x + (1-line),-1,1)
|
1361 |
+
h,w = x.shape[-2:]
|
1362 |
+
input = torch.cat([x, line], 1)
|
1363 |
+
input = self.npad(input)
|
1364 |
+
inter = self.enc(input)[:,:,:h,:w]
|
1365 |
+
scr, logvar = torch.split(inter, (self.outc, self.outc), dim=1)
|
1366 |
+
if output_screen_only:
|
1367 |
+
return scr
|
1368 |
+
recons = self.dec(scr)
|
1369 |
+
return recons, scr, logvar
|
1370 |
+
else:
|
1371 |
+
h,w = x.shape[-2:]
|
1372 |
+
x = self.npad(x)
|
1373 |
+
recons = self.dec(x)[:,:,:h,:w]
|
1374 |
+
recons = (recons+1)*(line+1)/2-1
|
1375 |
+
return torch.clamp(recons,-1,1)
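The padded sizes that `npad` produces follow from integer arithmetic alone. The sketch below is a plain-Python illustration; `padded_size` is a hypothetical helper, and the sample sizes are arbitrary.

```python
def padded_size(size, pad=128):
    """Same formula as npad: round down to a multiple of pad, then add one pad."""
    return size // pad * pad + pad


for size in (100, 128, 300, 512):
    print(size, "->", padded_size(size))
# 100 -> 128
# 128 -> 256
# 300 -> 384
# 512 -> 640
```

The padded size is always strictly larger than the input, which is why `forward` crops the encoder and decoder outputs back with `[:, :, :h, :w]`.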
|
BidirectionalTranslation/options/base_options.py
ADDED
@@ -0,0 +1,142 @@
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from util import util
|
4 |
+
import torch
|
5 |
+
import models
|
6 |
+
import data
|
7 |
+
|
8 |
+
class BaseOptions():
|
9 |
+
def __init__(self):
|
10 |
+
self.initialized = False
|
11 |
+
|
12 |
+
def initialize(self, parser):
|
13 |
+
"""Initialize options used during both training and test time."""
|
14 |
+
# Basic options
|
15 |
+
parser.add_argument('--dataroot', required=False, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
|
16 |
+
parser.add_argument('--batch_size', type=int, default=2, help='input batch size')
|
17 |
+
parser.add_argument('--load_size', type=int, default=512, help='scale images to this size') # Modified default
|
18 |
+
parser.add_argument('--crop_size', type=int, default=1024, help='then crop to this size') # Modified default
|
19 |
+
parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels') # Modified default
|
20 |
+
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') # Modified default
|
21 |
+
parser.add_argument('--nz', type=int, default=64, help='#latent vector') # Modified default
|
22 |
+
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, -1 for CPU mode')
|
23 |
+
parser.add_argument('--name', type=str, default='color2manga_cycle_ganstft', help='name of the experiment') # Modified default
|
24 |
+
parser.add_argument('--preprocess', type=str, default='none', help='not implemented') # Modified default
|
25 |
+
parser.add_argument('--dataset_mode', type=str, default='aligned', help='aligned,single')
|
26 |
+
parser.add_argument('--model', type=str, default='cycle_ganstft', help='chooses which model to use')
|
27 |
+
parser.add_argument('--direction', type=str, default='BtoA', help='AtoB or BtoA') # Modified default
|
28 |
+
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
|
29 |
+
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
|
30 |
+
parser.add_argument('--local_rank', default=0, type=int, help='local process rank for distributed training')
|
31 |
+
parser.add_argument('--checkpoints_dir', type=str, default=self.model_global_path+'/ScreenStyle/color2manga/', help='models are saved here') # Modified default
|
32 |
+
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
|
33 |
+
parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
|
34 |
+
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset.')
|
35 |
+
parser.add_argument('--no_flip', action='store_false', help='if specified, do not flip the images for data augmentation') # Modified default
|
36 |
+
|
37 |
+
# Model parameters
|
38 |
+
parser.add_argument('--level', type=int, default=0, help='level to train')
|
39 |
+
parser.add_argument('--num_Ds', type=int, default=2, help='number of Discriminators')
|
40 |
+
parser.add_argument('--netD', type=str, default='basic_256_multi', help='selects model to use for netD')
|
41 |
+
parser.add_argument('--netD2', type=str, default='basic_256_multi', help='selects model to use for netD2')
|
42 |
+
parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG')
|
43 |
+
parser.add_argument('--netC', type=str, default='unet_128', help='selects model to use for netC')
|
44 |
+
parser.add_argument('--netE', type=str, default='conv_256', help='selects model to use for netE')
|
45 |
+
parser.add_argument('--nef', type=int, default=48, help='# of encoder filters in the first conv layer') # Modified default
|
46 |
+
parser.add_argument('--ngf', type=int, default=48, help='# of gen filters in the last conv layer') # Modified default
|
47 |
+
parser.add_argument('--ndf', type=int, default=32, help='# of discrim filters in the first conv layer') # Modified default
|
48 |
+
parser.add_argument('--norm', type=str, default='layer', help='instance normalization or batch normalization')
|
49 |
+
parser.add_argument('--upsample', type=str, default='bilinear', help='basic | bilinear') # Modified default
|
50 |
+
parser.add_argument('--nl', type=str, default='prelu', help='non-linearity activation: relu | lrelu | elu | prelu')
|
51 |
+
parser.add_argument('--no_encode', action='store_true', help='if specified, print more debugging information')
|
52 |
+
parser.add_argument('--color2screen', action='store_true', help='continue training: load the latest model including RGB model') # Modified default
|
53 |
+
|
54 |
+
# Extra parameters
|
55 |
+
parser.add_argument('--where_add', type=str, default='all', help='input|all|middle; where to add z in the network G')
|
56 |
+
parser.add_argument('--conditional_D', action='store_true', help='if use conditional GAN for D')
|
57 |
+
parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal | xavier | kaiming | orthogonal]')
|
58 |
+
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
|
59 |
+
parser.add_argument('--center_crop', action='store_true', help='if apply for center cropping for the test') # Modified default
|
60 |
+
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
|
61 |
+
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
|
62 |
+
parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
|
63 |
+
|
64 |
+
# Special tasks
|
65 |
+
self.initialized = True
|
66 |
+
return parser
|
67 |
+
|
68 |
+
def gather_options(self):
|
69 |
+
"""Initialize our parser with basic options (only once)."""
|
70 |
+
if not self.initialized:
|
71 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
72 |
+
parser = self.initialize(parser)
|
73 |
+
|
74 |
+
# Get the basic options
|
75 |
+
opt, _ = parser.parse_known_args()
|
76 |
+
|
77 |
+
# Modify model-related parser options
|
78 |
+
model_name = opt.model
|
79 |
+
model_option_setter = models.get_option_setter(model_name)
|
80 |
+
parser = model_option_setter(parser, self.isTrain)
|
81 |
+
opt, _ = parser.parse_known_args() # Parse again with new defaults
|
82 |
+
|
83 |
+
# Modify dataset-related parser options
|
84 |
+
dataset_name = opt.dataset_mode
|
85 |
+
dataset_option_setter = data.get_option_setter(dataset_name)
|
86 |
+
parser = dataset_option_setter(parser, self.isTrain)
|
87 |
+
|
88 |
+
# Save and return the parser
|
89 |
+
self.parser = parser
|
90 |
+
return parser.parse_args()
|
91 |
+
|
92 |
+
def print_options(self, opt):
|
93 |
+
"""Print and save options."""
|
94 |
+
message = ''
|
95 |
+
message += '----------------- Options ---------------\n'
|
96 |
+
for k, v in sorted(vars(opt).items()):
|
97 |
+
comment = ''
|
98 |
+
default = self.parser.get_default(k)
|
99 |
+
if v != default:
|
100 |
+
comment = '\t[default: %s]' % str(default)
|
101 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
102 |
+
message += '----------------- End -------------------'
|
103 |
+
print(message)
|
104 |
+
|
105 |
+
# Save to the disk
|
106 |
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
107 |
+
if not os.path.exists(expr_dir):
|
108 |
+
try:
|
109 |
+
util.mkdirs(expr_dir)
|
110 |
+
except OSError:
|
111 |
+
pass
|
112 |
+
file_name = os.path.join(expr_dir, 'opt.txt')
|
113 |
+
with open(file_name, 'wt') as opt_file:
|
114 |
+
opt_file.write(message)
|
115 |
+
opt_file.write('\n')
|
116 |
+
|
117 |
+
def parse(self, model_global_path):
|
118 |
+
"""Parse options, create checkpoints directory suffix, and set up gpu device."""
|
119 |
+
self.model_global_path = model_global_path
|
120 |
+
opt = self.gather_options()
|
121 |
+
opt.isTrain = self.isTrain # train or test
|
122 |
+
|
123 |
+
|
124 |
+
# Process opt.suffix
|
125 |
+
if opt.suffix:
|
126 |
+
suffix = '_' + opt.suffix.format(**vars(opt))
|
127 |
+
opt.name = opt.name + suffix
|
128 |
+
|
129 |
+
self.print_options(opt)
|
130 |
+
|
131 |
+
# Set gpu ids
|
132 |
+
str_ids = opt.gpu_ids.split(',')
|
133 |
+
opt.gpu_ids = []
|
134 |
+
for str_id in str_ids:
|
135 |
+
id = int(str_id)
|
136 |
+
if id >= 0:
|
137 |
+
opt.gpu_ids.append(id)
|
138 |
+
if len(opt.gpu_ids) > 0:
|
139 |
+
torch.cuda.set_device(opt.gpu_ids[0])
|
140 |
+
|
141 |
+
self.opt = opt
|
142 |
+
return self.opt
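
The `gather_options` method above parses the command line twice: once leniently to learn which model was requested, and once strictly after that model has registered its own flags. A minimal, self-contained sketch of that pattern follows. The `cycle_options` setter, the `OPTION_SETTERS` table, and the `--lambda_cycle` flag are hypothetical stand-ins for illustration, not names from this repository.

```python
import argparse


def cycle_options(parser):
    """Hypothetical model-specific setter: adds a flag and overrides a default."""
    parser.add_argument('--lambda_cycle', type=float, default=10.0)
    parser.set_defaults(ngf=48)
    return parser


# Hypothetical registry; the real code looks setters up through the models package.
OPTION_SETTERS = {'cycle': cycle_options}


def gather_options(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='cycle')
    parser.add_argument('--ngf', type=int, default=64)

    # First pass: parse_known_args tolerates flags that are not registered yet.
    opt, _ = parser.parse_known_args(argv)

    # Let the chosen model extend the parser, then parse strictly.
    parser = OPTION_SETTERS[opt.model](parser)
    return parser.parse_args(argv)


opt = gather_options(['--model', 'cycle', '--lambda_cycle', '5'])
print(opt.model, opt.ngf, opt.lambda_cycle)
```

The lenient first pass is what allows `--lambda_cycle` to appear on the command line before the parser knows about it.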
|
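
The GPU-id handling at the end of `parse` can be read as a small pure function: split the comma-separated string, convert each token to an integer, and drop negative ids (conventionally `-1` means CPU). The sketch below is an illustration under that reading. The name `parse_gpu_ids` is hypothetical, and skipping empty tokens is an addition that the original loop does not make.

```python
def parse_gpu_ids(spec):
    """Turn a string such as '0,1' into [0, 1]; negative ids are dropped."""
    ids = []
    for token in spec.split(','):
        token = token.strip()
        if not token:
            # Assumption: tolerate stray commas instead of raising ValueError.
            continue
        gpu_id = int(token)
        if gpu_id >= 0:
            ids.append(gpu_id)
    return ids


print(parse_gpu_ids('0,1'))
print(parse_gpu_ids('-1'))
```

An empty result means no CUDA device is selected, which matches the `len(opt.gpu_ids) > 0` guard before `torch.cuda.set_device`.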
BidirectionalTranslation/options/test_options.py
ADDED
@@ -0,0 +1,19 @@
|
1 |
+
from .base_options import BaseOptions
|
2 |
+
|
3 |
+
class TestOptions(BaseOptions):
|
4 |
+
def initialize(self, parser):
|
5 |
+
BaseOptions.initialize(self, parser)
|
6 |
+
|
7 |
+
|
8 |
+
# Additional test-specific arguments
|
9 |
+
parser.add_argument('--results_dir', type=str, default='../results/', help='saves results here.')
|
10 |
+
parser.add_argument('--phase', type=str, default='val', help='train, val, test, etc')
|
11 |
+
parser.add_argument('--num_test', type=int, default=30, help='how many test images to run')
|
12 |
+
parser.add_argument('--n_samples', type=int, default=1, help='#samples')
|
13 |
+
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio for the results')
|
14 |
+
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
|
15 |
+
parser.add_argument('--folder', type=str, default='intra', help='name of the input sub-folder to test on.')
|
16 |
+
parser.add_argument('--sync', action='store_true', help='use the same latent code for different input images')
|
17 |
+
|
18 |
+
self.isTrain = False
|
19 |
+
return parser
|
BidirectionalTranslation/requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
1 |
+
torch~=1.6.0
|
2 |
+
torchvision~=0.4.0
|
3 |
+
tensorboardx~=1.9
|
4 |
+
scipy==1.1
|
5 |
+
dominate~=2.3.1
|
6 |
+
scikit-image~=0.16.2
|
7 |
+
opencv-python~=3.4.2
|
8 |
+
lpips
|
BidirectionalTranslation/scripts/test_western2manga.sh
ADDED
@@ -0,0 +1,49 @@
|
1 |
+
set -ex
|
2 |
+
# models
|
3 |
+
RESULTS_DIR='./results/test/western2manga'
|
4 |
+
|
5 |
+
# dataset
|
6 |
+
CLASS='color2manga'
|
7 |
+
MODEL='cycle_ganstft'
|
8 |
+
DIRECTION='BtoA' # from domain B to domain A
|
9 |
+
PREPROCESS='none'
|
10 |
+
LOAD_SIZE=512 # scale images to this size
|
11 |
+
CROP_SIZE=1024 # then crop to this size
|
12 |
+
INPUT_NC=1 # number of channels in the input image
|
13 |
+
OUTPUT_NC=3 # number of channels in the output image
|
14 |
+
NGF=48
|
15 |
+
NEF=48
|
16 |
+
NDF=32
|
17 |
+
NZ=64
|
18 |
+
|
19 |
+
# misc
|
20 |
+
GPU_ID=0 # gpu id
|
21 |
+
NUM_TEST=30 # number of input images during test
|
22 |
+
NUM_SAMPLES=1 # number of samples per input image
|
23 |
+
NAME=${CLASS}_${MODEL}
|
24 |
+
|
25 |
+
# command
|
26 |
+
CUDA_VISIBLE_DEVICES=${GPU_ID} \
|
27 |
+
python3 ./test.py \
|
28 |
+
--dataroot ./datasets/${CLASS} \
|
29 |
+
--results_dir ${RESULTS_DIR} \
|
30 |
+
--checkpoints_dir ./checkpoints/${CLASS}/ \
|
31 |
+
--name ${NAME} \
|
32 |
+
--model ${MODEL} \
|
33 |
+
--direction ${DIRECTION} \
|
34 |
+
--preprocess ${PREPROCESS} \
|
35 |
+
--load_size ${LOAD_SIZE} \
|
36 |
+
--crop_size ${CROP_SIZE} \
|
37 |
+
--input_nc ${INPUT_NC} \
|
38 |
+
--output_nc ${OUTPUT_NC} \
|
39 |
+
--nz ${NZ} \
|
40 |
+
--netE conv_256 \
|
41 |
+
--num_test ${NUM_TEST} \
|
42 |
+
--n_samples ${NUM_SAMPLES} \
|
43 |
+
--upsample bilinear \
|
44 |
+
--ngf ${NGF} \
|
45 |
+
--nef ${NEF} \
|
46 |
+
--ndf ${NDF} \
|
47 |
+
--center_crop \
|
48 |
+
--color2screen \
|
49 |
+
--no_flip
|
BidirectionalTranslation/test.py
ADDED
@@ -0,0 +1,71 @@
|
1 |
+
import os
|
2 |
+
from options.test_options import TestOptions
|
3 |
+
from data import create_dataset
|
4 |
+
from models import create_model
|
5 |
+
from util.visualizer import save_images
|
6 |
+
from itertools import islice
|
7 |
+
from util import html
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
seed = 10
|
11 |
+
import torch
|
12 |
+
import numpy as np
|
13 |
+
torch.manual_seed(seed)
|
14 |
+
torch.cuda.manual_seed(seed)
|
15 |
+
np.random.seed(seed)
|
16 |
+
|
17 |
+
# options
|
18 |
+
opt = TestOptions().parse()
|
19 |
+
opt.num_threads = 1 # test code only supports num_threads=1
|
20 |
+
opt.batch_size = 1 # test code only supports batch_size=1
|
21 |
+
opt.serial_batches = True # no shuffle
|
22 |
+
|
23 |
+
model = create_model(opt)
|
24 |
+
model.setup(opt)
|
25 |
+
model.eval()
|
26 |
+
print('Loading model %s' % opt.model)
|
27 |
+
|
28 |
+
testdata = ['manga_paper']
|
29 |
+
# fake_sty = model.get_z_random(1, 64, truncation=True)
|
30 |
+
|
31 |
+
opt.dataset_mode = 'singleSr'
|
32 |
+
for folder in testdata:
|
33 |
+
opt.folder = folder
|
34 |
+
# create dataset
|
35 |
+
dataset = create_dataset(opt)
|
36 |
+
web_dir = os.path.join(opt.results_dir, opt.folder + '_Sr2Co')
|
37 |
+
webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class =%s' % (opt.name, opt.phase, opt.name))
|
38 |
+
# fake_sty = model.get_z_random(1, 64, truncation=True)
|
39 |
+
for i, data in enumerate(islice(dataset, opt.num_test)):
|
40 |
+
h = data['h']
|
41 |
+
w = data['w']
|
42 |
+
model.set_input(data)
|
43 |
+
fake_sty = model.get_z_random(1, 64, truncation=True, tvalue=1.25)
|
44 |
+
fake_B, SCR, line = model.forward(AtoB=False, sty=fake_sty)
|
45 |
+
images=[fake_B[:,:,:h,:w]]
|
46 |
+
names=['color']
|
47 |
+
|
48 |
+
img_path = 'input_%3.3d' % i
|
49 |
+
save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
|
50 |
+
webpage.save()
|
51 |
+
|
52 |
+
testdata = ['western_paper']
|
53 |
+
|
54 |
+
opt.dataset_mode = 'singleCo'
|
55 |
+
for folder in testdata:
|
56 |
+
opt.folder = folder
|
57 |
+
# create dataset
|
58 |
+
dataset = create_dataset(opt)
|
59 |
+
web_dir = os.path.join(opt.results_dir, opt.folder + '_Sr2Co')
|
60 |
+
webpage = html.HTML(web_dir, 'Training = %s, Phase = %s, Class =%s' % (opt.name, opt.phase, opt.name))
|
61 |
+
for i, data in enumerate(islice(dataset, opt.num_test)):
|
62 |
+
h = data['h']
|
63 |
+
w = data['w']
|
64 |
+
model.set_input(data)
|
65 |
+
fake_B, fake_B2, SCR = model.forward(AtoB=True)
|
66 |
+
images=[fake_B2[:,:,:h,:w]]
|
67 |
+
names=['manga']
|
68 |
+
|
69 |
+
img_path = 'input_%3.3d' % i
|
70 |
+
save_images(webpage, images, names, img_path, aspect_ratio=opt.aspect_ratio, width=opt.crop_size)
|
71 |
+
webpage.save()
|
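
Both loops in `test.py` slice the network output with `[:, :, :h, :w]`, which suggests the dataset pads each image before inference and records the original height and width. The dataset code is not shown here, so the sketch below is an assumption about that step: the multiple of 32, the zero fill, and the helper names are all hypothetical, and nested lists stand in for tensors.

```python
def padded_size(h, w, multiple=32):
    """Smallest (height, width) >= (h, w) where both are multiples of `multiple`."""
    return h + (-h) % multiple, w + (-w) % multiple


def pad_image(image, multiple=32, fill=0):
    """Pad a 2-D list on the bottom and right; also return the original size."""
    h, w = len(image), len(image[0])
    new_h, new_w = padded_size(h, w, multiple)
    padded = [row + [fill] * (new_w - w) for row in image]
    padded += [[fill] * new_w for _ in range(new_h - h)]
    return padded, h, w


def crop_image(image, h, w):
    """Undo pad_image: keep the top-left h x w region."""
    return [row[:w] for row in image[:h]]


image = [[1, 2, 3], [4, 5, 6]]
padded, h, w = pad_image(image, multiple=4)
restored = crop_image(padded, h, w)
print(len(padded), len(padded[0]), restored == image)
```

Padding only on the bottom and right is what makes a plain top-left slice sufficient to recover the original extent.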
BidirectionalTranslation/util/html.py
ADDED
@@ -0,0 +1,86 @@
|
1 |
+
import dominate
|
2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class HTML:
|
7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
8 |
+
|
9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, web_dir, title, refresh=0):
|
15 |
+
"""Initialize the HTML class
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
|
19 |
+
title (str) -- the webpage name
|
20 |
+
refresh (int) -- how often the website refreshes itself, in seconds; if 0, no refreshing
|
21 |
+
"""
|
22 |
+
self.title = title
|
23 |
+
self.web_dir = web_dir
|
24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
25 |
+
if not os.path.exists(self.web_dir):
|
26 |
+
os.makedirs(self.web_dir)
|
27 |
+
if not os.path.exists(self.img_dir):
|
28 |
+
os.makedirs(self.img_dir)
|
29 |
+
|
30 |
+
self.doc = dominate.document(title=title)
|
31 |
+
if refresh > 0:
|
32 |
+
with self.doc.head:
|
33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
34 |
+
|
35 |
+
def get_image_dir(self):
|
36 |
+
"""Return the directory that stores images"""
|
37 |
+
return self.img_dir
|
38 |
+
|
39 |
+
def add_header(self, text):
|
40 |
+
"""Insert a header to the HTML file
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
text (str) -- the header text
|
44 |
+
"""
|
45 |
+
with self.doc:
|
46 |
+
h3(text)
|
47 |
+
|
48 |
+
def add_images(self, ims, txts, links, width=400):
|
49 |
+
"""add images to the HTML file
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
ims (str list) -- a list of image paths
|
53 |
+
txts (str list) -- a list of image names shown on the website
|
54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
55 |
+
"""
|
56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
57 |
+
self.doc.add(self.t)
|
58 |
+
with self.t:
|
59 |
+
with tr():
|
60 |
+
for im, txt, link in zip(ims, txts, links):
|
61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
62 |
+
with p():
|
63 |
+
with a(href=os.path.join('images', link)):
|
64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
65 |
+
br()
|
66 |
+
p(txt)
|
67 |
+
|
68 |
+
def save(self):
|
69 |
+
"""save the current content to the HTML file"""
|
70 |
+
html_file = '%s/index.html' % self.web_dir
|
71 |
+
f = open(html_file, 'wt')
|
72 |
+
f.write(self.doc.render())
|
73 |
+
f.close()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__': # we show an example usage here.
|
77 |
+
html = HTML('web/', 'test_html')
|
78 |
+
html.add_header('hello world')
|
79 |
+
|
80 |
+
ims, txts, links = [], [], []
|
81 |
+
for n in range(4):
|
82 |
+
ims.append('image_%d.png' % n)
|
83 |
+
txts.append('text_%d' % n)
|
84 |
+
links.append('image_%d.png' % n)
|
85 |
+
html.add_images(ims, txts, links)
|
86 |
+
html.save()
|
BidirectionalTranslation/util/util.py
ADDED
@@ -0,0 +1,136 @@
|
1 |
+
from __future__ import print_function
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
import pickle
|
7 |
+
|
8 |
+
|
9 |
+
def tensor2im(input_image, imtype=np.uint8):
|
10 |
+
"""Convert a Tensor array into a numpy image array.
|
11 |
+
Parameters:
|
12 |
+
input_image (tensor) -- the input image tensor array
|
13 |
+
imtype (type) -- the desired type of the converted numpy array
|
14 |
+
"""
|
15 |
+
if not isinstance(input_image, np.ndarray):
|
16 |
+
if isinstance(input_image, torch.Tensor): # get the data from a variable
|
17 |
+
image_tensor = input_image.data
|
18 |
+
else:
|
19 |
+
return input_image
|
20 |
+
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
|
21 |
+
if image_numpy.shape[0] == 1: # grayscale to RGB
|
22 |
+
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
23 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
|
24 |
+
else: # if it is a numpy array, do nothing
|
25 |
+
image_numpy = input_image
|
26 |
+
return image_numpy.astype(imtype)
|
27 |
+
|
28 |
+
|
29 |
+
def tensor2vec(vector_tensor):
|
30 |
+
numpy_vec = vector_tensor.data.cpu().numpy()
|
31 |
+
if numpy_vec.ndim == 4:
|
32 |
+
return numpy_vec[:, :, 0, 0]
|
33 |
+
else:
|
34 |
+
return numpy_vec
|
35 |
+
|
36 |
+
|
37 |
+
def pickle_load(file_name):
|
38 |
+
data = None
|
39 |
+
with open(file_name, 'rb') as f:
|
40 |
+
data = pickle.load(f)
|
41 |
+
return data
|
42 |
+
|
43 |
+
|
44 |
+
def pickle_save(file_name, data):
|
45 |
+
with open(file_name, 'wb') as f:
|
46 |
+
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
47 |
+
|
48 |
+
|
49 |
+
def diagnose_network(net, name='network'):
|
50 |
+
"""Calculate and print the mean of average absolute(gradients)
|
51 |
+
Parameters:
|
52 |
+
net (torch network) -- Torch network
|
53 |
+
name (str) -- the name of the network
|
54 |
+
"""
|
55 |
+
mean = 0.0
|
56 |
+
count = 0
|
57 |
+
for param in net.parameters():
|
58 |
+
if param.grad is not None:
|
59 |
+
mean += torch.mean(torch.abs(param.grad.data))
|
60 |
+
count += 1
|
61 |
+
if count > 0:
|
62 |
+
mean = mean / count
|
63 |
+
print(name)
|
64 |
+
print(mean)
|
65 |
+
|
66 |
+
|
67 |
+
def interp_z(z0, z1, num_frames, interp_mode='linear'):
|
68 |
+
zs = []
|
69 |
+
if interp_mode == 'linear':
|
70 |
+
for n in range(num_frames):
|
71 |
+
ratio = n / float(num_frames - 1)
|
72 |
+
z_t = (1 - ratio) * z0 + ratio * z1
|
73 |
+
zs.append(z_t[np.newaxis, :])
|
74 |
+
zs = np.concatenate(zs, axis=0).astype(np.float32)
|
75 |
+
|
76 |
+
if interp_mode == 'slerp':
|
77 |
+
z0_n = z0 / (np.linalg.norm(z0) + 1e-10)
|
78 |
+
z1_n = z1 / (np.linalg.norm(z1) + 1e-10)
|
79 |
+
omega = np.arccos(np.clip(np.dot(z0_n, z1_n), -1.0, 1.0))
|
80 |
+
sin_omega = np.sin(omega)
|
81 |
+
if sin_omega < 1e-10 and sin_omega > -1e-10:
|
82 |
+
zs = interp_z(z0, z1, num_frames, interp_mode='linear')
|
83 |
+
else:
|
84 |
+
for n in range(num_frames):
|
85 |
+
ratio = n / float(num_frames - 1)
|
86 |
+
z_t = np.sin((1 - ratio) * omega) / sin_omega * z0 + np.sin(ratio * omega) / sin_omega * z1
|
87 |
+
zs.append(z_t[np.newaxis, :])
|
88 |
+
zs = np.concatenate(zs, axis=0).astype(np.float32)
|
89 |
+
|
90 |
+
return zs
|
91 |
+
|
92 |
+
|
93 |
+
def save_image(image_numpy, image_path):
|
94 |
+
"""Save a numpy image to the disk
|
95 |
+
Parameters:
|
96 |
+
image_numpy (numpy array) -- input numpy array
|
97 |
+
image_path (str) -- the path of the image
|
98 |
+
"""
|
99 |
+
image_pil = Image.fromarray(image_numpy)
|
100 |
+
image_pil.save(image_path)
|
101 |
+
|
102 |
+
|
103 |
+
def print_numpy(x, val=True, shp=False):
|
104 |
+
"""Print the mean, min, max, median, std, and size of a numpy array
|
105 |
+
Parameters:
|
106 |
+
val (bool) -- if print the values of the numpy array
|
107 |
+
shp (bool) -- if print the shape of the numpy array
|
108 |
+
"""
|
109 |
+
x = x.astype(np.float64)
|
110 |
+
if shp:
|
111 |
+
print('shape,', x.shape)
|
112 |
+
if val:
|
113 |
+
x = x.flatten()
|
114 |
+
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
115 |
+
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
116 |
+
|
117 |
+
|
118 |
+
def mkdirs(paths):
|
119 |
+
"""create empty directories if they don't exist
|
120 |
+
Parameters:
|
121 |
+
paths (str list) -- a list of directory paths
|
122 |
+
"""
|
123 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
124 |
+
for path in paths:
|
125 |
+
mkdir(path)
|
126 |
+
else:
|
127 |
+
mkdir(paths)
|
128 |
+
|
129 |
+
|
130 |
+
def mkdir(path):
|
131 |
+
"""create a single empty directory if it didn't exist
|
132 |
+
Parameters:
|
133 |
+
path (str) -- a single directory path
|
134 |
+
"""
|
135 |
+
if not os.path.exists(path):
|
136 |
+
os.makedirs(path, exist_ok=True)
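
`tensor2im` maps network outputs from the range [-1, 1] to [0, 255] with `(x + 1) / 2 * 255` and then casts to `uint8`. A per-value sketch of that mapping is below. The name `to_uint8` is hypothetical, and the explicit clamp is an addition: the original cast has no clamp, so values outside [-1, 1] are not guarded there.

```python
def to_uint8(value):
    """Map a float in [-1, 1] to an integer in [0, 255], truncating like astype."""
    scaled = (value + 1.0) / 2.0 * 255.0
    # Assumption: clamp out-of-range inputs instead of letting the cast misbehave.
    scaled = max(0.0, min(255.0, scaled))
    return int(scaled)


print(to_uint8(-1.0), to_uint8(0.0), to_uint8(1.0))
```

Because the cast truncates rather than rounds, the midpoint 0.0 lands on 127, not 128.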
|
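
The `slerp` branch of `interp_z` interpolates along the arc between two latent vectors, weighting them by `sin((1 - t) * omega) / sin(omega)` and `sin(t * omega) / sin(omega)`, where `omega` is the angle between them. The standalone sketch below uses only the standard library. The function name `slerp`, the clamp on the cosine, and the `num_frames < 2` check are additions for illustration, not part of the original function.

```python
import math


def slerp(z0, z1, num_frames):
    """Spherical interpolation between two equal-length vectors (lists of floats)."""
    if num_frames < 2:
        # n / (num_frames - 1) would divide by zero for a single frame.
        raise ValueError('num_frames must be at least 2')

    norm0 = math.sqrt(sum(a * a for a in z0)) + 1e-10
    norm1 = math.sqrt(sum(b * b for b in z1)) + 1e-10
    cos_omega = sum(a * b for a, b in zip(z0, z1)) / (norm0 * norm1)
    # Clamp so rounding error cannot push the value outside acos's domain.
    cos_omega = max(-1.0, min(1.0, cos_omega))
    omega = math.acos(cos_omega)
    sin_omega = math.sin(omega)

    frames = []
    for n in range(num_frames):
        t = n / (num_frames - 1)
        if abs(sin_omega) < 1e-10:
            # Nearly parallel vectors: fall back to linear weights.
            w0, w1 = 1.0 - t, t
        else:
            w0 = math.sin((1.0 - t) * omega) / sin_omega
            w1 = math.sin(t * omega) / sin_omega
        frames.append([w0 * a + w1 * b for a, b in zip(z0, z1)])
    return frames


frames = slerp([1.0, 0.0], [0.0, 1.0], 3)
print(frames[1])
```

For two orthogonal unit vectors the middle frame stays on the unit circle, which is the property that distinguishes slerp from linear interpolation.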
BidirectionalTranslation/util/visualizer.py
ADDED
@@ -0,0 +1,221 @@
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import ntpath
|
5 |
+
import time
|
6 |
+
from . import util
|
7 |
+
from . import html
|
8 |
+
from subprocess import Popen, PIPE
|
9 |
+
import cv2
|
10 |
+
|
11 |
+
|
12 |
+
# if sys.version_info[0] == 2:
|
13 |
+
# VisdomExceptionBase = Exception
|
14 |
+
# else:
|
15 |
+
# VisdomExceptionBase = ConnectionError
|
16 |
+
|
17 |
+
|
18 |
+
def save_images(webpage, images, names, image_path, aspect_ratio=1.0, width=256):
|
19 |
+
"""Save images to the disk.
|
20 |
+
Parameters:
|
21 |
+
webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
|
22 |
+
images (numpy array list) -- a list of numpy array that stores images
|
23 |
+
names (str list) -- a str list stores the names of the images above
|
24 |
+
image_path (str) -- the string is used to create image paths
|
25 |
+
aspect_ratio (float) -- the aspect ratio of saved images
|
26 |
+
width (int) -- the images will be resized to width x width
|
27 |
+
This function will save the images in 'images' to disk and add them to the HTML file specified by 'webpage'.
|
28 |
+
"""
|
29 |
+
image_dir = webpage.get_image_dir()
|
30 |
+
name = ntpath.basename(image_path)
|
31 |
+
|
32 |
+
webpage.add_header(name)
|
33 |
+
ims, txts, links = [], [], []
|
34 |
+
|
35 |
+
for label, im_data in zip(names, images):
|
36 |
+
im = util.tensor2im(im_data)
|
37 |
+
image_name = '%s_%s.jpg' % (name, label)
|
38 |
+
save_path = os.path.join(image_dir, image_name)
|
39 |
+
h, w, _ = im.shape
|
40 |
+
if aspect_ratio > 1.0:
|
41 |
+
im = cv2.resize(im, (int(w * aspect_ratio), h), interpolation=cv2.INTER_CUBIC)  # cv2 dsize is (width, height)
|
42 |
+
if aspect_ratio < 1.0:
|
43 |
+
im = cv2.resize(im, (w, int(h / aspect_ratio)), interpolation=cv2.INTER_CUBIC)  # cv2 dsize is (width, height)
|
44 |
+
util.save_image(im, save_path)
|
45 |
+
|
46 |
+
ims.append(image_name)
|
47 |
+
txts.append(label)
|
48 |
+
links.append(image_name)
|
49 |
+
webpage.add_images(ims, txts, links, width=width)
|
50 |
+
|
51 |
+
|
52 |
+
class Visualizer():
|
53 |
+
"""This class includes several functions that can display/save images and print/save logging information.
|
54 |
+
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, opt):
|
58 |
+
"""Initialize the Visualizer class
|
59 |
+
Parameters:
|
60 |
+
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
61 |
+
Step 1: Cache the training/test options
|
62 |
+
Step 2: connect to a visdom server
|
63 |
+
Step 3: create an HTML object for saving HTML files
|
64 |
+
Step 4: create a logging file to store training losses
|
65 |
+
"""
|
66 |
+
self.opt = opt # cache the option
|
67 |
+
self.display_id = opt.display_id
|
68 |
+
self.use_html = opt.isTrain and not opt.no_html
|
69 |
+
self.win_size = opt.display_winsize
|
70 |
+
self.name = opt.name
|
71 |
+
self.port = opt.display_port
|
72 |
+
self.saved = False
|
73 |
+
# if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
|
74 |
+
# import visdom
|
75 |
+
# self.ncols = opt.display_ncols
|
76 |
+
# self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
|
77 |
+
# if not self.vis.check_connection():
|
78 |
+
# self.create_visdom_connections()
|
79 |
+
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
|
80 |
+
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
81 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
82 |
+
print('create web directory %s...' % self.web_dir)
|
83 |
+
util.mkdirs([self.web_dir, self.img_dir])
|
84 |
+
# create a logging file to store training losses
|
85 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
86 |
+
with open(self.log_name, "a") as log_file:
|
87 |
+
now = time.strftime("%c")
|
88 |
+
log_file.write('================ Training Loss (%s) ================\n' % now)
|
89 |
+
|
90 |
+
def reset(self):
|
91 |
+
"""Reset the self.saved status"""
|
92 |
+
self.saved = False
|
93 |
+
|
94 |
+
def create_visdom_connections(self):
|
95 |
+
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
|
96 |
+
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
|
97 |
+
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
|
98 |
+
print('Command: %s' % cmd)
|
99 |
+
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
|
100 |
+
|
101 |
+
def display_current_results(self, visuals, epoch, save_result):
|
102 |
+
"""Display current results on visdom; save current results to an HTML file.
|
103 |
+
Parameters:
|
104 |
+
visuals (OrderedDict) -- dictionary of images to display or save
|
105 |
+
epoch (int) -- the current epoch
|
106 |
+
save_result (bool) -- whether to save the current results to an HTML file
|
107 |
+
"""
|
108 |
+
# if self.display_id > 0: # show images in the browser using visdom
|
109 |
+
# ncols = self.ncols
|
110 |
+
# if ncols > 0: # show all the images in one visdom panel
|
111 |
+
# ncols = min(ncols, len(visuals))
|
112 |
+
# h, w = next(iter(visuals.values())).shape[:2]
|
113 |
+
# table_css = """<style>
|
114 |
+
# table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
|
115 |
+
# table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
|
116 |
+
# </style>""" % (w, h) # create a table css
|
117 |
+
# # create a table of images.
|
118 |
+
# title = self.name
|
119 |
+
# label_html = ''
|
120 |
+
# label_html_row = ''
|
121 |
+
# images = []
|
122 |
+
# idx = 0
|
123 |
+
# for label, image in visuals.items():
|
124 |
+
# image_numpy = util.tensor2im(image)
|
125 |
+
# label_html_row += '<td>%s</td>' % label
|
126 |
+
# images.append(image_numpy.transpose([2, 0, 1]))
|
127 |
+
# idx += 1
|
128 |
+
# if idx % ncols == 0:
|
129 |
+
# label_html += '<tr>%s</tr>' % label_html_row
|
130 |
+
# label_html_row = ''
|
131 |
+
# white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
|
132 |
+
# while idx % ncols != 0:
|
133 |
+
# images.append(white_image)
|
134 |
+
# label_html_row += '<td></td>'
|
135 |
+
# idx += 1
|
136 |
+
# if label_html_row != '':
|
137 |
+
# label_html += '<tr>%s</tr>' % label_html_row
|
138 |
+
# try:
|
139 |
+
# self.vis.images(images, nrow=ncols, win=self.display_id + 1,
|
140 |
+
# padding=2, opts=dict(title=title + ' images'))
|
141 |
+
# label_html = '<table>%s</table>' % label_html
|
142 |
+
# self.vis.text(table_css + label_html, win=self.display_id + 2,
|
143 |
+
# opts=dict(title=title + ' labels'))
|
144 |
+
# except VisdomExceptionBase:
|
145 |
+
# self.create_visdom_connections()
|
146 |
+
|
147 |
+
# else: # show each image in a separate visdom panel;
|
148 |
+
# idx = 1
|
149 |
+
# try:
|
150 |
+
# for label, image in visuals.items():
|
151 |
+
# image_numpy = util.tensor2im(image)
|
152 |
+
# self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
|
153 |
+
# win=self.display_id + idx)
|
154 |
+
# idx += 1
|
155 |
+
# except VisdomExceptionBase:
|
156 |
+
# self.create_visdom_connections()
|
157 |
+
|
158 |
+
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
|
159 |
+
self.saved = True
|
160 |
+
# save images to the disk
|
161 |
+
for label, image in visuals.items():
|
162 |
+
image_numpy = util.tensor2im(image)
|
163 |
+
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
|
164 |
+
util.save_image(image_numpy, img_path)
|
165 |
+
|
166 |
+
# update website
|
167 |
+
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
|
168 |
+
for n in range(epoch, 0, -1):
|
169 |
+
webpage.add_header('epoch [%d]' % n)
|
170 |
+
ims, txts, links = [], [], []
|
171 |
+
|
172 |
+
for label, image_numpy in visuals.items():
|
173 |
+
image_numpy = util.tensor2im(image_numpy)
|
174 |
+
img_path = 'epoch%.3d_%s.png' % (n, label)
|
175 |
+
ims.append(img_path)
|
176 |
+
txts.append(label)
|
177 |
+
links.append(img_path)
|
178 |
+
webpage.add_images(ims, txts, links, width=self.win_size)
|
179 |
+
webpage.save()
|
180 |
+
|
181 |
+
def plot_current_losses(self, epoch, counter_ratio, losses):
|
182 |
+
"""display the current losses on visdom display: dictionary of error labels and values
|
183 |
+
Parameters:
|
184 |
+
epoch (int) -- current epoch
|
185 |
+
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
|
186 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
187 |
+
"""
|
188 |
+
if not hasattr(self, 'plot_data'):
|
189 |
+
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
|
190 |
+
self.plot_data['X'].append(epoch + counter_ratio)
|
191 |
+
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
|
192 |
+
# try:
|
193 |
+
# self.vis.line(
|
194 |
+
# X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
|
195 |
+
# Y=np.array(self.plot_data['Y']),
|
196 |
+
# opts={
|
197 |
+
# 'title': self.name + ' loss over time',
|
198 |
+
# 'legend': self.plot_data['legend'],
|
199 |
+
# 'xlabel': 'epoch',
|
200 |
+
# 'ylabel': 'loss'},
|
201 |
+
# win=self.display_id)
|
202 |
+
# except VisdomExceptionBase:
|
203 |
+
# self.create_visdom_connections()
|
204 |
+
|
205 |
+
# losses: same format as |losses| of plot_current_losses
|
206 |
+
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
|
207 |
+
"""print current losses on console; also save the losses to the disk
|
208 |
+
Parameters:
|
209 |
+
epoch (int) -- current epoch
|
210 |
+
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
|
211 |
+
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
|
212 |
+
t_comp (float) -- computational time per data point (normalized by batch_size)
|
213 |
+
t_data (float) -- data loading time per data point (normalized by batch_size)
|
214 |
+
"""
|
215 |
+
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
|
216 |
+
for k, v in losses.items():
|
217 |
+
message += '%s: %.3f ' % (k, v)
|
218 |
+
|
219 |
+
print(message) # print the message
|
220 |
+
with open(self.log_name, "a") as log_file:
|
221 |
+
log_file.write('%s\n' % message) # save the message
|
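
`save_images` stretches an image according to `aspect_ratio` before saving it. One detail worth spelling out is argument order: `cv2.resize` takes its target size as `(width, height)`, while a NumPy image's `shape` is `(height, width, channels)`. The helper below is a sketch of the size computation only. The name `resize_target` is hypothetical, and it returns the tuple in the order `cv2.resize` expects.

```python
def resize_target(h, w, aspect_ratio):
    """Return (width, height) for an h x w image stretched by aspect_ratio."""
    if aspect_ratio > 1.0:
        # Widen the image, keep its height.
        return int(w * aspect_ratio), h
    if aspect_ratio < 1.0:
        # Make the image taller, keep its width.
        return w, int(h / aspect_ratio)
    return w, h


print(resize_target(100, 200, 2.0))
print(resize_target(100, 200, 0.5))
```

Keeping the computation in one place makes it easy to check the order of the two numbers before handing them to the resize call.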
app.py
ADDED
@@ -0,0 +1,507 @@
|
+import contextlib
+import gc
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import sys
+import time
+import itertools
+from pathlib import Path
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+import accelerate
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from safetensors.torch import load_model
+from peft import LoraConfig
+import gradio as gr
+import pandas as pd
+
+import transformers
+from transformers import (
+    AutoTokenizer,
+    PretrainedConfig,
+    CLIPVisionModelWithProjection,
+    CLIPImageProcessor,
+    CLIPProcessor,
+)
+
+import diffusers
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+    ColorGuiderPixArtModel,
+    ColorGuiderSDModel,
+    UNet2DConditionModel,
+    PixArtTransformer2DModel,
+    ColorFlowPixArtAlphaPipeline,
+    ColorFlowSDPipeline,
+    UniPCMultistepScheduler,
+)
+from util_colorflow.utils import *
+
+sys.path.append('./BidirectionalTranslation')
+from options.test_options import TestOptions
+from models import create_model
+from util import util
+
+from huggingface_hub import snapshot_download
+
+model_global_path = snapshot_download(repo_id="JunhaoZhuang/ColorFlow", cache_dir='./colorflow/')
+print(model_global_path)
+
+
+transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+])
+weight_dtype = torch.float16
+
+# line model
+line_model_path = model_global_path + '/LE/erika.pth'
+line_model = res_skip()
+line_model.load_state_dict(torch.load(line_model_path))
+line_model.eval()
+line_model.cuda()
+
+# screen model
+global opt
+
+opt = TestOptions().parse(model_global_path)
+ScreenModel = create_model(opt, model_global_path)
+ScreenModel.setup(opt)
+ScreenModel.eval()
+
+image_processor = CLIPImageProcessor()
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to('cuda')
+
+
+examples = [
+    [
+        "./assets/example_5/input.png",
+        ["./assets/example_5/ref1.png", "./assets/example_5/ref2.png", "./assets/example_5/ref3.png"],
+        "GrayImage(ScreenStyle)",
+        "800x512",
+        0,
+        10
+    ],
+    [
+        "./assets/example_4/input.jpg",
+        ["./assets/example_4/ref1.jpg", "./assets/example_4/ref2.jpg", "./assets/example_4/ref3.jpg"],
+        "GrayImage(ScreenStyle)",
+        "640x640",
+        0,
+        10
+    ],
+    [
+        "./assets/example_3/input.png",
+        ["./assets/example_3/ref1.png", "./assets/example_3/ref2.png", "./assets/example_3/ref3.png"],
+        "GrayImage(ScreenStyle)",
+        "800x512",
+        0,
+        10
+    ],
+    [
+        "./assets/example_2/input.png",
+        ["./assets/example_2/ref1.png", "./assets/example_2/ref2.png", "./assets/example_2/ref3.png"],
+        "GrayImage(ScreenStyle)",
+        "800x512",
+        0,
+        10
+    ],
+    [
+        "./assets/example_1/input.jpg",
+        ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
+        "Sketch",
+        "640x640",
+        0,
+        10
+    ],
+    [
+        "./assets/example_0/input.jpg",
+        ["./assets/example_0/ref1.jpg"],
+        "Sketch",
+        "640x640",
+        0,
+        10
+    ],
+]
+
+global pipeline
+global MultiResNetModel
+
+def load_ckpt(input_style):
+    global pipeline
+    global MultiResNetModel
+    if input_style == "Sketch":
+        ckpt_path = model_global_path + '/sketch/'
+        rank = 128
+        pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
+        transformer = PixArtTransformer2DModel.from_pretrained(
+            pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
+        )
+        pixart_config = get_pixart_config()
+
+        ColorGuider = ColorGuiderPixArtModel.from_pretrained(ckpt_path)
+
+        transformer_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=rank,
+            init_lora_weights="gaussian",
+            target_modules=["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2", "proj", "linear", "linear_1", "linear_2"]
+        )
+        transformer.add_adapter(transformer_lora_config)
+        ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
+        transformer.load_state_dict(ckpt_key_t, strict=False)
+
+        transformer.to('cuda', dtype=weight_dtype)
+        ColorGuider.to('cuda', dtype=weight_dtype)
+
+        pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
+            pretrained_model_name_or_path,
+            transformer=transformer,
+            colorguider=ColorGuider,
+            safety_checker=None,
+            revision=None,
+            variant=None,
+            torch_dtype=weight_dtype,
+        )
+        pipeline = pipeline.to("cuda")
+        block_out_channels = [128, 128, 256, 512, 512]
+
+        MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+        MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
+        MultiResNetModel.to('cuda', dtype=weight_dtype)
+
+    elif input_style == "GrayImage(ScreenStyle)":
+        ckpt_path = model_global_path + '/GraySD/'
+        rank = 64
+        pretrained_model_name_or_path = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
+        unet = UNet2DConditionModel.from_pretrained(
+            pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
+        )
+        ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
+        ColorGuider.to('cuda', dtype=weight_dtype)
+        unet.to('cuda', dtype=weight_dtype)
+
+        pipeline = ColorFlowSDPipeline.from_pretrained(
+            pretrained_model_name_or_path,
+            unet=unet,
+            colorguider=ColorGuider,
+            safety_checker=None,
+            revision=None,
+            variant=None,
+            torch_dtype=weight_dtype,
+        )
+        pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+        unet_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=rank,
+            init_lora_weights="gaussian",
+            target_modules=["to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],#ff.net.0.proj ff.net.2
+        )
+        pipeline.unet.add_adapter(unet_lora_config)
+        pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
+        pipeline = pipeline.to("cuda")
+        block_out_channels = [128, 128, 256, 512, 512]
+
+        MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+        MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
+        MultiResNetModel.to('cuda', dtype=weight_dtype)
+
+
+
+
+
+global cur_input_style
+cur_input_style = "Sketch"
+load_ckpt(cur_input_style)
+cur_input_style = "GrayImage(ScreenStyle)"
+load_ckpt(cur_input_style)
+
+
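Review note on `load_ckpt`: the function overwrites the `pipeline` and `MultiResNetModel` globals each time it runs, so the two start-up calls leave only the `GrayImage(ScreenStyle)` pipeline in place and a later switch to `Sketch` rebuilds everything. One possible alternative, sketched below under the assumption that memory allows holding both pipelines at once, is to cache loaded objects per style. `StyleCache` and `fake_loader` are hypothetical names used only for illustration; they are not part of this commit.

```python
from typing import Callable, Dict, Generic, TypeVar

T = TypeVar("T")


class StyleCache(Generic[T]):
    """Keep one loaded object per input style (hypothetical helper)."""

    def __init__(self, loader: Callable[[str], T]) -> None:
        self._loader = loader            # function that builds the object for a style
        self._loaded: Dict[str, T] = {}  # style name -> already-built object
        self.load_count = 0              # how many times the loader actually ran

    def get(self, style: str) -> T:
        # Build on first request only; later requests reuse the cached object.
        if style not in self._loaded:
            self._loaded[style] = self._loader(style)
            self.load_count += 1
        return self._loaded[style]


def fake_loader(style: str) -> str:
    # Stand-in for the expensive checkpoint load done by load_ckpt.
    return f"pipeline for {style}"


cache = StyleCache(fake_loader)
cache.get("Sketch")
cache.get("GrayImage(ScreenStyle)")
cache.get("Sketch")  # served from the cache, no reload
```

The trade-off is GPU memory: keeping both pipelines resident may not fit on every device, in which case reloading on switch, as the commit does, is the safer choice.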
+def fix_random_seeds(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+def process_multi_images(files):
+    images = [Image.open(file.name) for file in files]
+    imgs = []
+    for i, img in enumerate(images):
+        imgs.append(img)
+    return imgs
+
+def extract_lines(image):
+    src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+
+    rows = int(np.ceil(src.shape[0] / 16)) * 16
+    cols = int(np.ceil(src.shape[1] / 16)) * 16
+
+    patch = np.ones((1, 1, rows, cols), dtype="float32")
+    patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
+
+    tensor = torch.from_numpy(patch).cuda()
+
+    with torch.no_grad():
+        y = line_model(tensor)
+
+    yc = y.cpu().numpy()[0, 0, :, :]
+    yc[yc > 255] = 255
+    yc[yc < 0] = 0
+
+    outimg = yc[0:src.shape[0], 0:src.shape[1]]
+    outimg = outimg.astype(np.uint8)
+    outimg = Image.fromarray(outimg)
+    torch.cuda.empty_cache()
+    return outimg
+
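Review note on `extract_lines`: the `rows`/`cols` computation rounds each image dimension up to the next multiple of 16 before the image is fed to `line_model`, and the result is cropped back to the source size afterwards. The reason is most likely that the network downsamples by powers of two, though the commit does not say so. The arithmetic can be sketched on its own as follows; `round_up_to_multiple` and `padded_shape` are hypothetical helper names, not functions from this commit.

```python
import math


def round_up_to_multiple(size: int, multiple: int = 16) -> int:
    """Return the smallest multiple of `multiple` that is >= size."""
    return math.ceil(size / multiple) * multiple


def padded_shape(height: int, width: int, multiple: int = 16) -> tuple:
    # Mirrors the rows/cols computation at the top of extract_lines.
    return (
        round_up_to_multiple(height, multiple),
        round_up_to_multiple(width, multiple),
    )


# A 960x960 input is already aligned; a 961x770 input is padded up.
aligned = padded_shape(960, 960)
padded = padded_shape(961, 770)
```

Sizes that are already multiples of 16 are left unchanged, so the padding only costs extra compute for unaligned inputs.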
+def to_screen_image(input_image):
+    global opt
+    global ScreenModel
+    input_image = input_image.convert('RGB')
+    input_image = get_ScreenVAE_input(input_image, opt)
+    h = input_image['h']
+    w = input_image['w']
+    ScreenModel.set_input(input_image)
+    fake_B, fake_B2, SCR = ScreenModel.forward(AtoB=True)
+    images=fake_B2[:,:,:h,:w]
+    im = util.tensor2im(images)
+    image_pil = Image.fromarray(im)
+    torch.cuda.empty_cache()
+    return image_pil
+
+def extract_line_image(query_image_, input_style, resolution):
+    if resolution == "640x640":
+        tar_width = 640
+        tar_height = 640
+    elif resolution == "512x800":
+        tar_width = 512
+        tar_height = 800
+    elif resolution == "800x512":
+        tar_width = 800
+        tar_height = 512
+    else:
+        gr.Info("Unsupported resolution")
+
+    query_image = process_image(query_image_, int(tar_width*1.5), int(tar_height*1.5))
+    if input_style == "GrayImage(ScreenStyle)":
+        extracted_line = to_screen_image(query_image)
+        extracted_line = Image.blend(extracted_line.convert('L').convert('RGB'), query_image.convert('L').convert('RGB'), 0.5)
+        input_context = extracted_line
+    elif input_style == "Sketch":
+        query_image = query_image.convert('L').convert('RGB')
+        extracted_line = extract_lines(query_image)
+        extracted_line = extracted_line.convert('L').convert('RGB')
+        input_context = extracted_line
+    torch.cuda.empty_cache()
+    return input_context, extracted_line, input_context
+
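Review note on the resolution handling: `extract_line_image` and `colorize_image` both repeat the same `if`/`elif` chain, and the `else` branch only shows a `gr.Info` notice. As far as I can tell, an unsupported value would then fall through and fail with an unbound `tar_width`. A minimal sketch of a shared helper that fails explicitly is below; `parse_resolution` and `SUPPORTED_RESOLUTIONS` are hypothetical names, not part of this commit.

```python
SUPPORTED_RESOLUTIONS = ("640x640", "512x800", "800x512")


def parse_resolution(resolution: str) -> tuple:
    """Return (width, height) for a supported 'WIDTHxHEIGHT' label."""
    if resolution not in SUPPORTED_RESOLUTIONS:
        # Fail early instead of continuing with undefined dimensions.
        raise ValueError(f"Unsupported resolution: {resolution!r}")
    width, height = (int(part) for part in resolution.split("x"))
    return width, height


tar_width, tar_height = parse_resolution("800x512")
```

Inside a Gradio handler, raising `gr.Error` instead of `ValueError` would show the message in the UI; the plain exception is used here only to keep the sketch free of dependencies.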
+def colorize_image(VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps):
+    if VAE_input is None or input_context is None:
+        gr.Info("Please preprocess the image first")
+        raise ValueError("Please preprocess the image first")
+    global cur_input_style
+    global pipeline
+    global MultiResNetModel
+    if input_style != cur_input_style:
+        gr.Info(f"Loading {input_style} model...")
+        load_ckpt(input_style)
+        cur_input_style = input_style
+        gr.Info(f"{input_style} model loaded")
+    reference_images = process_multi_images(reference_images)
+    fix_random_seeds(seed)
+    if resolution == "640x640":
+        tar_width = 640
+        tar_height = 640
+    elif resolution == "512x800":
+        tar_width = 512
+        tar_height = 800
+    elif resolution == "800x512":
+        tar_width = 800
+        tar_height = 512
+    else:
+        gr.Info("Unsupported resolution")
+    validation_mask = Image.open('./assets/mask.png').convert('RGB').resize((tar_width*2, tar_height*2))
+    gr.Info("Image retrieval in progress...")
+    query_image_bw = process_image(input_context, int(tar_width), int(tar_height))
+    query_image = query_image_bw.convert('RGB')
+    query_image_vae = process_image(VAE_input, int(tar_width*1.5), int(tar_height*1.5))
+    reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
+    query_patches_pil = process_image_Q_varres(query_image, tar_width, tar_height)
+    reference_patches_pil = []
+    for reference_image in reference_images:
+        reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
+    combined_image = None
+    with torch.no_grad():
+        clip_img = image_processor(images=query_patches_pil, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+        query_embeddings = image_encoder(clip_img).image_embeds
+        reference_patches_pil_gray = [rimg.convert('RGB').convert('RGB') for rimg in reference_patches_pil]
+        clip_img = image_processor(images=reference_patches_pil_gray, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+        reference_embeddings = image_encoder(clip_img).image_embeds
+        cosine_similarities = F.cosine_similarity(query_embeddings.unsqueeze(1), reference_embeddings.unsqueeze(0), dim=-1)
+        sorted_indices = torch.argsort(cosine_similarities, descending=True, dim=1).tolist()
+        top_k = 3
+        top_k_indices = [cur_sortlist[:top_k] for cur_sortlist in sorted_indices]
+        combined_image = Image.new('RGB', (tar_width * 2, tar_height * 2), 'white')
+        combined_image.paste(query_image_bw.resize((tar_width, tar_height)), (tar_width//2, tar_height//2))
+        idx_table = {0:[(1,0), (0,1), (0,0)], 1:[(1,3), (0,2),(0,3)], 2:[(2,0),(3,1), (3,0)], 3:[(2,3), (3,2),(3,3)]}
+        for i in range(2):
+            for j in range(2):
+                idx_list = idx_table[i * 2 + j]
+                for k in range(top_k):
+                    ref_index = top_k_indices[i * 2 + j][k]
+                    idx_y = idx_list[k][0]
+                    idx_x = idx_list[k][1]
+                    combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
+    gr.Info("Model inference in progress...")
+    generator = torch.Generator(device='cuda').manual_seed(seed)
+    image = pipeline(
+        "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
+    ).images[0]
+    gr.Info("Post-processing image...")
+    with torch.no_grad():
+        width, height = image.size
+        new_width = width // 2
+        new_height = height // 2
+        left = (width - new_width) // 2
+        top = (height - new_height) // 2
+        right = left + new_width
+        bottom = top + new_height
+        center_crop = image.crop((left, top, right, bottom))
+        up_img = center_crop.resize(query_image_vae.size)
+        test_low_color = transform(up_img).unsqueeze(0).to('cuda', dtype=weight_dtype)
+        query_image_vae = transform(query_image_vae).unsqueeze(0).to('cuda', dtype=weight_dtype)
+
+        h_color, hidden_list_color = pipeline.vae._encode(test_low_color,return_dict = False, hidden_flag = True)
+        h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict = False, hidden_flag = True)
+
+        hidden_list_double = [torch.cat((hidden_list_color[hidden_idx], hidden_list_bw[hidden_idx]), dim = 1) for hidden_idx in range(len(hidden_list_color))]
+
+
+        hidden_list = MultiResNetModel(hidden_list_double)
+        output = pipeline.vae._decode(h_color.sample(),return_dict = False, hidden_list = hidden_list)[0]
+
+        output[output > 1] = 1
+        output[output < -1] = -1
+        high_res_image = Image.fromarray(((output[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)).convert("RGB")
+    gr.Info("Colorization complete!")
+    torch.cuda.empty_cache()
+    return high_res_image, up_img, image, query_image_bw
+
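Review note on the retrieval step in `colorize_image`: each query patch is compared with every reference patch by cosine similarity of CLIP image embeddings, and the three best-matching reference patches per query patch are pasted around the query in the composite image. The ranking itself can be illustrated without torch. The sketch below uses plain Python lists in place of embeddings and assumes non-zero vectors; `cosine_similarity` and `top_k_references` are hypothetical names, not functions from this commit.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    # Dot product divided by the product of the vector lengths.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_references(
    queries: Sequence[Sequence[float]],
    references: Sequence[Sequence[float]],
    k: int = 3,
) -> List[List[int]]:
    """For each query vector, return indices of the k most similar references."""
    ranked = []
    for query in queries:
        scores = [cosine_similarity(query, ref) for ref in references]
        # Sort reference indices by descending similarity, then keep the first k.
        order = sorted(range(len(references)), key=lambda i: scores[i], reverse=True)
        ranked.append(order[:k])
    return ranked


queries = [[1.0, 0.0], [0.0, 1.0]]
references = [[0.9, 0.1], [0.1, 0.9], [0.7, 0.7], [-1.0, 0.0]]
result = top_k_references(queries, references, k=2)
# result == [[0, 2], [1, 2]]
```

The commit does the same ranking in batch with `F.cosine_similarity` and `torch.argsort(..., descending=True)`, which is the better fit for GPU tensors; the loop form here is only meant to make the logic easy to follow.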
+with gr.Blocks() as demo:
+    gr.HTML(
+        """
+        <div style="text-align: center;">
+            <h1 style="text-align: center; font-size: 3em;">🎨 ColorFlow:</h1>
+            <h3 style="text-align: center; font-size: 1.8em;">Retrieval-Augmented Image Sequence Colorization</h3>
+            <p style="text-align: center; font-weight: bold;">
+                <a href="https://zhuang2002.github.io/ColorFlow/">Project Page</a> |
+                <a href="https://arxiv.org/abs/">ArXiv Preprint</a> |
+                <a href="https://github.com/TencentARC/ColorFlow">GitHub Repository</a>
+            </p>
+            <p style="text-align: center; font-weight: bold;">
+                NOTE: Each time you switch the input style, the corresponding model will be reloaded, which may take some time. Please be patient.
+            </p>
+            <p style="text-align: left; font-size: 1.1em;">
+                Welcome to the demo of <strong>ColorFlow</strong>. Follow the steps below to explore the capabilities of our model:
+            </p>
+        </div>
+        <div style="text-align: left; margin: 0 auto;">
+            <ol style="font-size: 1.1em;">
+                <li>Choose input style: GrayImage(ScreenStyle) or Sketch.</li>
+                <li>Upload your image: Use the 'Upload' button to select the image you want to colorize.</li>
+                <li>Preprocess the image: Click the 'Preprocess' button to decolorize the image.</li>
+                <li>Upload reference images: Upload multiple reference images to guide the colorization.</li>
+                <li>Set sampling parameters (optional): Adjust the settings and click the <b>Colorize</b> button.</li>
+            </ol>
+            <p>
+                ⏱️ <b>ZeroGPU Time Limit</b>: Hugging Face ZeroGPU has an inference time limit of 180 seconds. You may need to log in with a free account to use this demo. Large sampling steps might lead to timeout (GPU Abort). In that case, please consider logging in with a Pro account or running it on your local machine.
+            </p>
+        </div>
+        <div style="text-align: center;">
+            <p style="text-align: center; font-weight: bold;">
+                注意:每次切换输入样式时,相应的模型将被重新加载,可能需要一些时间。请耐心等待。
+            </p>
+            <p style="text-align: left; font-size: 1.1em;">
+                欢迎使用 <strong>ColorFlow</strong> 演示。请按照以下步骤探索我们模型的能力:
+            </p>
+        </div>
+        <div style="text-align: left; margin: 0 auto;">
+            <ol style="font-size: 1.1em;">
+                <li>选择输入样式:灰度图(ScreenStyle)、线稿。</li>
+                <li>上传您的图像:使用“上传”按钮选择要上色的图像。</li>
+                <li>预处理图像:点击“预处理”按钮以去色图像。</li>
+                <li>上传参考图像:上传多张参考图像以指导上色。</li>
+                <li>设置采样参数(可选):调整设置并点击 <b>上色</b> 按钮。</li>
+            </ol>
+            <p>
+                ⏱️ <b>ZeroGPU时间限制</b>:Hugging Face ZeroGPU 的推理时间限制为 180 秒。您可能需要使用免费帐户登录以使用此演示。大采样步骤可能会导致超时(GPU 中止)。在这种情况下,请考虑使用专业帐户登录或在本地计算机上运行。
+            </p>
+        </div>
+        """
+    )
+    VAE_input = gr.State()
+    input_context = gr.State()
+    # example_loading = gr.State(value=None)
+
+    with gr.Column():
+        with gr.Row():
+            input_style = gr.Radio(["GrayImage(ScreenStyle)", "Sketch"], label="Input Style", value="GrayImage(ScreenStyle)")
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(type="pil", label="Image to Colorize")
+                resolution = gr.Radio(["640x640", "512x800", "800x512"], label="Select Resolution(Width*Height)", value="640x640")
+                extract_button = gr.Button("Preprocess (Decolorize)")
+            extracted_image = gr.Image(type="pil", label="Decolorized Result")
+        with gr.Row():
+            reference_images = gr.Files(label="Reference Images (Upload multiple)", file_count="multiple")
+            with gr.Column():
+                output_gallery = gr.Gallery(label="Colorization Results", type="pil")
+                seed = gr.Slider(label="Random Seed", minimum=0, maximum=100000, value=0, step=1)
+                num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=100, value=10, step=1)
+                colorize_button = gr.Button("Colorize")
+
+    # progress_text = gr.Textbox(label="Progress", interactive=False)
+
+
+    extract_button.click(
+        extract_line_image,
+        inputs=[input_image, input_style, resolution],
+        outputs=[extracted_image, VAE_input, input_context]
+    )
+    colorize_button.click(
+        colorize_image,
+        inputs=[VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps],
+        outputs=output_gallery
+    )
+
+    with gr.Column():
+        gr.Markdown("### Quick Examples")
+        gr.Examples(
+            examples=examples,
+            inputs=[input_image, reference_images, input_style, resolution, seed, num_inference_steps],
+            label="Examples",
+            examples_per_page=6,
+        )
+demo.launch(server_name="0.0.0.0", server_port=22348)
assets/example_0/input.jpg
ADDED
assets/example_0/ref1.jpg
ADDED
assets/example_1/input.jpg
ADDED
assets/example_1/ref1.jpg
ADDED
assets/example_1/ref2.jpg
ADDED
assets/example_1/ref3.jpg
ADDED
assets/example_2/input.png
ADDED
assets/example_2/ref1.png
ADDED
assets/example_2/ref2.png
ADDED
assets/example_2/ref3.png
ADDED
assets/example_3/input.png
ADDED
assets/example_3/ref1.png
ADDED
assets/example_3/ref2.png
ADDED
assets/example_3/ref3.png
ADDED
assets/example_4/input.jpg
ADDED
assets/example_4/ref1.jpg
ADDED
assets/example_4/ref2.jpg
ADDED
assets/example_4/ref3.jpg
ADDED
assets/example_5/input.png
ADDED
assets/example_5/ref1.png
ADDED
assets/example_5/ref2.png
ADDED
assets/example_5/ref3.png
ADDED
assets/mask.png
ADDED
diffusers/.github/ISSUE_TEMPLATE/bug-report.yml
ADDED
@@ -0,0 +1,110 @@
+name: "\U0001F41B Bug Report"
+description: Report a bug on Diffusers
+labels: [ "bug" ]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thanks a lot for taking the time to file this issue 🤗.
+        Issues do not only help to improve the library, but also publicly document common problems, questions, workflows for the whole community!
+        Thus, issues are of the same importance as pull requests when contributing to this library ❤️.
+        In order to make your issue as **useful for the community as possible**, let's try to stick to some simple guidelines:
+        - 1. Please try to be as precise and concise as possible.
+             *Give your issue a fitting title. Assume that someone which very limited knowledge of Diffusers can understand your issue. Add links to the source code, documentation other issues, pull requests etc...*
+        - 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.
+             *The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
+        - 3. Add the **minimum** amount of code / context that is needed to understand, reproduce your issue.
+             *Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
+        - 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
+  - type: markdown
+    attributes:
+      value: |
+        For more in-detail information on how to write good issues you can have a look [here](https://huggingface.co/course/chapter8/5?fw=pt).
+  - type: textarea
+    id: bug-description
+    attributes:
+      label: Describe the bug
+      description: A clear and concise description of what the bug is. If you intend to submit a pull request for this issue, tell us in the description. Thanks!
+      placeholder: Bug description
+    validations:
+      required: true
+  - type: textarea
+    id: reproduction
+    attributes:
+      label: Reproduction
+      description: Please provide a minimal reproducible code which we can copy/paste and reproduce the issue.
+      placeholder: Reproduction
+    validations:
+      required: true
+  - type: textarea
+    id: logs
+    attributes:
+      label: Logs
+      description: "Please include the Python logs if you can."
+      render: shell
+  - type: textarea
+    id: system-info
+    attributes:
+      label: System Info
+      description: Please share your system info with us. You can run the command `diffusers-cli env` and copy-paste its output below.
+      placeholder: Diffusers version, platform, Python version, ...
+    validations:
+      required: true
+  - type: textarea
+    id: who-can-help
+    attributes:
+      label: Who can help?
+      description: |
+        Your issue will be replied to more quickly if you can figure out the right person to tag with @.
+        If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
+
+        All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
+        a core maintainer will ping the right person.
+
+        Please tag a maximum of 2 people.
+
+        Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): @sayakpaul @DN6
+
+        Questions on pipelines:
+        - Stable Diffusion @yiyixuxu @asomoza
+        - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
+        - Stable Diffusion 3: @yiyixuxu @sayakpaul @DN6 @asomoza
+        - Kandinsky @yiyixuxu
+        - ControlNet @sayakpaul @yiyixuxu @DN6
+        - T2I Adapter @sayakpaul @yiyixuxu @DN6
+        - IF @DN6
+        - Text-to-Video / Video-to-Video @DN6 @a-r-r-o-w
+        - Wuerstchen @DN6
+        - Other: @yiyixuxu @DN6
+        - Improving generation quality: @asomoza
+
+        Questions on models:
+        - UNet @DN6 @yiyixuxu @sayakpaul
+        - VAE @sayakpaul @DN6 @yiyixuxu
+        - Transformers/Attention @DN6 @yiyixuxu @sayakpaul
+
+        Questions on single file checkpoints: @DN6
+
+        Questions on Schedulers: @yiyixuxu
+
+        Questions on LoRA: @sayakpaul
+
+        Questions on Textual Inversion: @sayakpaul
+
+        Questions on Training:
+        - DreamBooth @sayakpaul
+        - Text-to-Image Fine-tuning @sayakpaul
+        - Textual Inversion @sayakpaul
+        - ControlNet @sayakpaul
+
+        Questions on Tests: @DN6 @sayakpaul @yiyixuxu
+
+        Questions on Documentation: @stevhliu
+
+        Questions on JAX- and MPS-related things: @pcuenca
+
+        Questions on audio pipelines: @sanchit-gandhi
+
+
+
+      placeholder: "@Username ..."
diffusers/.github/ISSUE_TEMPLATE/config.yml
ADDED
@@ -0,0 +1,4 @@
+contact_links:
+  - name: Questions / Discussions
+    url: https://github.com/huggingface/diffusers/discussions
+    about: General usage questions and community discussions
diffusers/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,20 @@
+---
+name: "\U0001F680 Feature Request"
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is. Ex. I'm always frustrated when [...].
+
+**Describe the solution you'd like.**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered.**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context.**
+Add any other context or screenshots about the feature request here.
diffusers/.github/ISSUE_TEMPLATE/feedback.md
ADDED
@@ -0,0 +1,12 @@
+---
+name: "💬 Feedback about API Design"
+about: Give feedback about the current API design
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**What API design would you like to have changed or added to the library? Why?**
+
+**What use case would this enable or better enable? Can you give us a code example?**
diffusers/.github/ISSUE_TEMPLATE/new-model-addition.yml
ADDED
@@ -0,0 +1,31 @@
+name: "\U0001F31F New Model/Pipeline/Scheduler Addition"
+description: Submit a proposal/request to implement a new diffusion model/pipeline/scheduler
+labels: [ "New model/pipeline/scheduler" ]
+
+body:
+  - type: textarea
+    id: description-request
+    validations:
+      required: true
+    attributes:
+      label: Model/Pipeline/Scheduler description
+      description: |
+        Put any and all important information relative to the model/pipeline/scheduler
+
+  - type: checkboxes
+    id: information-tasks
+    attributes:
+      label: Open source status
+      description: |
+        Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `diffusers`.
+      options:
+        - label: "The model implementation is available."
+        - label: "The model weights are available (Only relevant if addition is not a scheduler)."
+
+  - type: textarea
+    id: additional-info
+    attributes:
+      label: Provide useful links for the implementation
+      description: |
+        Please provide information regarding the implementation, the weights, and the authors.
+        Please mention the authors by @gh-username if you're aware of their usernames.
diffusers/.github/ISSUE_TEMPLATE/translate.md
ADDED
@@ -0,0 +1,29 @@
+---
+name: 🌐 Translating a New Language?
+about: Start a new translation effort in your language
+title: '[<languageCode>] Translating docs to <languageName>'
+labels: WIP
+assignees: ''
+
+---
+
+<!--
+Note: Please search to see if an issue already exists for the language you are trying to translate.
+-->
+
+Hi!
+
+Let's bring the documentation to all the <languageName>-speaking community 🌐.
+
+Who would want to translate? Please follow the 🤗 [TRANSLATING guide](https://github.com/huggingface/diffusers/blob/main/docs/TRANSLATING.md). Here is a list of the files ready for translation. Let us know in this issue if you'd like to translate any, and we'll add your name to the list.
+
+Some notes:
+
+* Please translate using an informal tone (imagine you are talking with a friend about Diffusers 🤗).
+* Please translate in a gender-neutral way.
+* Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/diffusers/tree/main/docs/source).
+* Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml).
+* Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.
+* 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63).
+
+Thank you so much for your help! 🤗