xco2 committed on
Commit
ebf6d7b
·
1 Parent(s): 687cb7c
Files changed (2) hide show
  1. net/UNet.py +0 -96
  2. requirements.txt +2 -179
net/UNet.py CHANGED
@@ -422,99 +422,3 @@ class UNet(nn.Module):
422
  # print("decoder:")
423
  # print(decoder_out.shape)
424
  return decoder_out
425
-
426
-
427
- if __name__ == '__main__':
428
- import cv2, os
429
-
430
-
431
- def modelSave(model, save_path, save_name):
432
- if not os.path.exists(save_path):
433
- os.mkdir(save_path)
434
- torch.save(model.state_dict(), os.path.join(save_path, save_name))
435
-
436
-
437
- def merge_images(images: np.ndarray):
438
- """
439
- 合并图像
440
- :param images: 图像数组
441
- :return: 合并后的图像数组
442
- """
443
- n, h, w, c = images.shape
444
- nn = int(np.ceil(n ** 0.5))
445
- merged_image = np.zeros((h * nn, w * nn, 3), dtype=images.dtype)
446
- for i in range(n):
447
- row = i // nn
448
- col = i % nn
449
- merged_image[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = images[i]
450
-
451
- merged_image = np.clip(merged_image, 0, 255)
452
- merged_image = np.array(merged_image, dtype=np.uint8)
453
- return merged_image
454
-
455
-
456
- # 320,448,576,832
457
- config = { # 模型结构相关
458
- "en_out_c": (256, 256, 256, 320, 320, 320, 576, 576, 576, 704, 704, 704),
459
- "en_down": (0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
460
- "en_skip": (0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1),
461
- "en_att_heads": (8, 8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8),
462
- "de_out_c": (704, 576, 576, 576, 320, 320, 320, 256, 256, 256, 256),
463
- "de_up": ("none", "subpix", "none", "none", "subpix", "none", "none", "subpix", "none", "none", "none"),
464
- "de_skip": (1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0),
465
- "de_att_heads": (8, 8, 0, 8, 8, 0, 8, 8, 0, 8, 8), # skip的地方不做self-attention
466
- "t_out_c": 256,
467
- "vae_c": 4,
468
- "block_deep": 3,
469
- }
470
- device = "cuda"
471
- total_step = 1000
472
-
473
- unet = UNet(config["en_out_c"], config["en_down"], config["en_skip"], config["en_att_heads"],
474
- config["de_out_c"], config["de_up"], config["de_skip"], config["de_att_heads"],
475
- config["t_out_c"], config["vae_c"], config["block_deep"]).to(device)
476
-
477
- print("总参数", sum(i.numel() for i in unet.parameters()) / 10000, "单位:万")
478
- print("encoder", sum(i.numel() for i in unet.encoder.parameters()) / 10000, "单位:万")
479
- print("decoder", sum(i.numel() for i in unet.decoder.parameters()) / 10000, "单位:万")
480
- print("t", sum(i.numel() for i in unet.t_encoder.parameters()) / 10000, "单位:万")
481
-
482
- batch_size = 2
483
- x = np.random.random((batch_size, config["vae_c"], 32, 32))
484
- t = np.random.uniform(1, total_step + 0.9999, size=(batch_size, 1))
485
- t = np.array(t, dtype=np.int16)
486
- t = t / total_step
487
-
488
- with torch.no_grad():
489
- x = torch.Tensor(x).to(device)
490
- t = torch.Tensor(t).to(device)
491
- y = unet(x, t)
492
- print(y.shape)
493
-
494
- z = y[0].cpu().numpy()
495
- # z = (z - np.mean(z)) / (np.max(z) - np.min(z))
496
- z = np.clip(np.asarray((z + 1) * 127.5), 0, 255)
497
- z = np.asarray(z, dtype=np.uint8)
498
-
499
- z = [np.tile(z[ii, :, :, np.newaxis], (1, 1, 3)) for ii in range(z.shape[0])]
500
- noise = merge_images(np.array(z))
501
-
502
- noise = cv2.resize(noise, None, fx=2, fy=2)
503
- cv2.imshow("noise", noise)
504
- cv2.waitKey(0)
505
-
506
- # modelSave(unet, "./", "test.pth")
507
- # 导出为onnx格式
508
- torch.onnx.export(
509
- unet,
510
- (x, t),
511
- 'unet.onnx',
512
- export_params=True,
513
- opset_version=12,
514
- )
515
- import onnx
516
-
517
- # 增加维度信息
518
- model_file = 'unet.onnx'
519
- onnx_model = onnx.load(model_file)
520
- onnx.save(onnx.shape_inference.infer_shapes(onnx_model), model_file)
 
422
  # print("decoder:")
423
  # print(decoder_out.shape)
424
  return decoder_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,186 +1,9 @@
1
- absl-py==1.3.0
2
- addict==2.4.0
3
- aiofiles==23.1.0
4
- aiohttp==3.8.3
5
- aiosignal==1.3.1
6
- aliyun-python-sdk-core==2.13.36
7
- aliyun-python-sdk-kms==2.16.0
8
- altair==4.2.0
9
- anyio==3.6.2
10
- appdirs==1.4.4
11
- asttokens==2.3.0
12
- async-timeout==4.0.2
13
- attrs==22.1.0
14
- audioread==3.0.0
15
- backcall==0.2.0
16
- certifi==2022.12.7
17
- cffi==1.15.1
18
- charset-normalizer==2.1.1
19
- chumpy==0.70
20
- click==8.1.3
21
- clip==1.0
22
- colorama==0.4.6
23
- commonmark==0.9.1
24
- contourpy==1.0.6
25
- cpm-kernels==1.0.11
26
- crcmod==1.7
27
- cryptography==39.0.2
28
- cycler==0.11.0
29
- Cython==0.29.32
30
- datasets==2.8.0
31
- decorator==5.1.1
32
- decord==0.6.0
33
- diffusers==0.20.1
34
- dill==0.3.6
35
- docker-pycreds==0.4.0
36
- einops==0.6.0
37
- entrypoints==0.4
38
- exceptiongroup==1.1.3
39
- executing==1.2.0
40
- fastapi==0.88.0
41
- ffmpy==0.3.0
42
- filelock==3.8.2
43
- Flask==2.0.2
44
- Flask-Cors==3.0.10
45
- fonttools==4.38.0
46
- frozenlist==1.3.3
47
- fsspec==2022.11.0
48
- ftfy==6.1.1
49
- gast==0.5.3
50
- gitdb==4.0.10
51
- GitPython==3.1.32
52
- gradio==3.39.0
53
- gradio_client==0.3.0
54
- h11==0.14.0
55
- httpcore==0.16.2
56
- httpx==0.23.1
57
  huggingface-hub==0.16.4
58
- icetk==0.0.4
59
- idna==3.4
60
- importlib-metadata==5.2.0
61
- ipython==8.15.0
62
- itsdangerous==2.1.2
63
- jedi==0.19.0
64
- Jinja2==3.1.2
65
- jmespath==0.10.0
66
- joblib==1.2.0
67
- json-tricks==3.16.1
68
- jsonplus==0.8.0
69
- jsonschema==4.17.3
70
- kiwisolver==1.4.4
71
- lazy_loader==0.1
72
- librosa==0.10.0
73
- linkify-it-py==1.0.3
74
- lion-pytorch==0.1.2
75
- llvmlite==0.39.1
76
- loguru==0.6.0
77
- Markdown==3.4.1
78
- markdown-it-py==2.1.0
79
- MarkupSafe==2.1.1
80
- matplotlib==3.6.2
81
- matplotlib-inline==0.1.6
82
- mdit-py-plugins==0.3.3
83
- mdurl==0.1.2
84
- mediapipe==0.8.11
85
- mmcv-full==1.7.0
86
- mmdet==2.26.0
87
- model-index==0.1.11
88
- modelscope==1.3.2
89
- mpmath==1.2.1
90
- msgpack==1.0.4
91
- multidict==6.0.3
92
- multiprocess==0.70.14
93
- munkres==1.1.4
94
- networkx==3.0
95
- numba==0.56.4
96
  numpy==1.23.4
97
- onnx==1.14.1
98
- opencv-contrib-python==4.5.5.64
99
- opencv-python==4.5.5.64
100
- openmim==0.3.3
101
- ordered-set==4.1.0
102
- orjson==3.8.3
103
- oss2==2.16.0
104
- packaging==21.3
105
- pandas==1.5.2
106
- parso==0.8.3
107
- pathtools==0.1.2
108
- pickleshare==0.7.5
109
- Pillow==9.2.0
110
- pip==23.1.2
111
- platformdirs==3.1.0
112
- plotly==5.11.0
113
- pooch==1.7.0
114
- prodigyopt==1.0
115
- prompt-toolkit==3.0.39
116
- protobuf==4.24.2
117
- psutil==5.9.5
118
- pure-eval==0.2.2
119
- pyarrow==11.0.0
120
- pycocotools==2.0.6
121
- pycparser==2.21
122
- pycryptodome==3.16.0
123
- pydantic==1.10.2
124
- pydub==0.25.1
125
- Pygments==2.13.0
126
- pyparsing==3.0.9
127
- pyrsistent==0.19.2
128
- python-dateutil==2.8.2
129
- python-multipart==0.0.5
130
- pytorch-fid==0.3.0
131
- pytz==2022.6
132
- PyYAML==6.0
133
- regex==2022.10.31
134
- requests==2.28.1
135
- responses==0.18.0
136
- rfc3986==1.5.0
137
- rich==12.6.0
138
- safetensors==0.3.3
139
- scikit-learn==1.2.1
140
- scipy==1.9.3
141
- semantic-version==2.10.0
142
- sentencepiece==0.1.97
143
- sentry-sdk==1.28.0
144
- setproctitle==1.3.2
145
- setuptools==65.5.0
146
- simplejson==3.18.3
147
- six==1.16.0
148
- smmap==5.0.0
149
- sniffio==1.3.0
150
- sortedcontainers==2.4.0
151
- soundfile==0.12.1
152
- soxr==0.3.4
153
- stack-data==0.6.2
154
- starlette==0.22.0
155
- sympy==1.11.1
156
- tabulate==0.9.0
157
- tenacity==8.1.0
158
- terminaltables==3.1.10
159
- threadpoolctl==3.1.0
160
- timm==0.4.9
161
- tokenizers==0.13.2
162
- toolz==0.12.0
163
  torch==2.0.0+cu117
164
  torchaudio==2.0.1+cu117
165
  torchinfo==1.7.1
166
  torchvision==0.15.1+cu117
167
  tqdm==4.64.1
168
- traitlets==5.9.0
169
- transformers==4.26.1
170
- typing_extensions==4.4.0
171
- uc-micro-py==1.0.1
172
- unicodedata2==15.0.0
173
- urllib3==1.26.12
174
- uvicorn==0.20.0
175
- wandb==0.15.5
176
- wcwidth==0.2.5
177
- websockets==10.4
178
- Werkzeug==2.2.2
179
- wheel==0.37.1
180
- win32-setctime==1.1.0
181
- wincertstore==0.2
182
- xtcocotools==1.12
183
- xxhash==3.2.0
184
- yapf==0.32.0
185
- yarl==1.8.2
186
- zipp==3.11.0
 
1
+ gradio
2
+ gradio_client
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  huggingface-hub==0.16.4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  numpy==1.23.4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  torch==2.0.0+cu117
6
  torchaudio==2.0.1+cu117
7
  torchinfo==1.7.1
8
  torchvision==0.15.1+cu117
9
  tqdm==4.64.1