update from climategan space
Browse files- README.md +29 -4
- app.py +274 -41
- climategan/trainer.py +41 -25
- climategan/tutils.py +46 -4
- climategan/utils.py +3 -3
- climategan_wrapper.py +624 -0
- inferences.py +0 -108
- requirements.txt +169 -0
README.md
CHANGED
@@ -11,11 +11,10 @@ emoji: 🌎
|
|
11 |
colorFrom: blue
|
12 |
colorTo: green
|
13 |
sdk: gradio
|
14 |
-
sdk_version:
|
15 |
app_file: app.py
|
16 |
inference: true
|
17 |
-
|
18 |
-
# -
|
19 |
---
|
20 |
|
21 |
# ClimateGAN: Raising Awareness about Climate Change by Generating Images of Floods
|
@@ -38,7 +37,33 @@ If you use this code, data or pre-trained weights, please cite our ICLR 2022 pap
|
|
38 |
}
|
39 |
```
|
40 |
|
41 |
-
## Using pre-trained weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
In the paper, we present ClimateGAN as a solution to produce images of floods. It can actually do **more**:
|
44 |
|
|
|
11 |
colorFrom: blue
|
12 |
colorTo: green
|
13 |
sdk: gradio
|
14 |
+
sdk_version: 3.6
|
15 |
app_file: app.py
|
16 |
inference: true
|
17 |
+
pinned: true
|
|
|
18 |
---
|
19 |
|
20 |
# ClimateGAN: Raising Awareness about Climate Change by Generating Images of Floods
|
|
|
37 |
}
|
38 |
```
|
39 |
|
40 |
+
## Using pre-trained weights from this Huggingface Space and Stable Diffusion In-painting
|
41 |
+
|
42 |
+
[**Huggingface ClimateGAN Space**](https://huggingface.co/spaces/vict0rsch/climateGAN)
|
43 |
+
|
44 |
+
1. Download code and model
|
45 |
+
```
|
46 |
+
git clone https://huggingface.co/spaces/vict0rsch/climateGAN hf-spaces-climategan
|
47 |
+
cd hf-spaces-climategan
|
48 |
+
git lfs install
|
49 |
+
git lfs pull
|
50 |
+
```
|
51 |
+
2. Install requirements
|
52 |
+
```
|
53 |
+
pip install requirements.txt
|
54 |
+
```
|
55 |
+
3. **Enable Stable Diffusion Inpainting** by visiting the model's card: https://huggingface.co/runwayml/stable-diffusion-inpainting **and** running `$ huggingface-cli login`
|
56 |
+
4. Run `$ python climategan_wrapper.py help` for usage instructions on how to infer on a folder's images.
|
57 |
+
5. Run `$ python app.py` to see the Gradio app.
|
58 |
+
1. To use Google Street View you'll need an API key and set the `GMAPS_API_KEY` environment variable.
|
59 |
+
2. To use Stable Diffusion if you can't run `$ huggingface-cli login` (on a Huggingface Space for instance) set the `HF_AUTH_TOKEN` env variable to a [Huggingface authorization token](https://huggingface.co/settings/tokens)
|
60 |
+
3. To change the UI without model overhead, set the `CG_DEV_MODE` environment variable to `true`.
|
61 |
+
|
62 |
+
For a more fine-grained control on ClimateGAN's inferences, refer to `apply_events.py` (does not support Stable Diffusion painter)
|
63 |
+
|
64 |
+
**Note:** you don't have control on the prompt by design because I disabled the safety checker. Fork this space/repo and do it yourself if you really need to change the prompt. At least [open a discussion](https://huggingface.co/spaces/vict0rsch/climateGAN/discussions).
|
65 |
+
|
66 |
+
## Using pre-trained weights from source
|
67 |
|
68 |
In the paper, we present ClimateGAN as a solution to produce images of floods. It can actually do **more**:
|
69 |
|
app.py
CHANGED
@@ -2,69 +2,302 @@
|
|
2 |
# thank you @NimaBoscarino
|
3 |
|
4 |
import os
|
5 |
-
|
|
|
|
|
|
|
6 |
import googlemaps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from skimage import io
|
8 |
-
from
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
|
|
|
13 |
def _predict(*args):
|
14 |
-
print("
|
15 |
-
image = place = None
|
16 |
-
if
|
17 |
-
image = args
|
18 |
else:
|
19 |
-
|
20 |
-
image, place = args
|
21 |
|
22 |
-
if api_key and place:
|
23 |
geocode_result = gmaps.geocode(place)
|
24 |
|
25 |
address = geocode_result[0]["formatted_address"]
|
26 |
static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
|
27 |
img_np = io.imread(static_map_url)
|
|
|
28 |
else:
|
|
|
29 |
img_np = image
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
return _predict
|
34 |
|
35 |
|
36 |
if __name__ == "__main__":
|
37 |
|
|
|
|
|
|
|
38 |
api_key = os.environ.get("GMAPS_API_KEY")
|
39 |
gmaps = None
|
40 |
if api_key is not None:
|
41 |
gmaps = googlemaps.Client(key=api_key)
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
# thank you @NimaBoscarino
|
3 |
|
4 |
import os
|
5 |
+
from textwrap import dedent
|
6 |
+
from urllib import parse
|
7 |
+
from requests import get
|
8 |
+
|
9 |
import googlemaps
|
10 |
+
import gradio as gr
|
11 |
+
import numpy as np
|
12 |
+
from gradio.components import (
|
13 |
+
HTML,
|
14 |
+
Button,
|
15 |
+
Column,
|
16 |
+
Dropdown,
|
17 |
+
Image,
|
18 |
+
Markdown,
|
19 |
+
Radio,
|
20 |
+
Row,
|
21 |
+
Textbox,
|
22 |
+
)
|
23 |
from skimage import io
|
24 |
+
from datetime import datetime
|
25 |
+
|
26 |
+
from climategan_wrapper import ClimateGAN
|
27 |
+
|
28 |
+
TEXTS = [
|
29 |
+
dedent(
|
30 |
+
"""
|
31 |
+
<p>
|
32 |
+
Climate change does not impact everyone equally.
|
33 |
+
This Space shows the effects of the climate emergency,
|
34 |
+
"one address at a time".
|
35 |
+
Visit the original experience at
|
36 |
+
<a href="https://thisclimatedoesnotexist.com/">
|
37 |
+
ThisClimateDoesNotExist.com
|
38 |
+
</a>
|
39 |
+
</p>
|
40 |
+
<br>
|
41 |
+
<p>
|
42 |
+
Enter an address or upload a Street View image, and ClimateGAN
|
43 |
+
will generate images showing how the location could be impacted
|
44 |
+
by flooding, wildfires, or smog if it happened there.
|
45 |
+
</p>
|
46 |
+
<br>
|
47 |
+
<p>
|
48 |
+
This is <strong>NOT</strong> an exercise in climate prediction,
|
49 |
+
rather an exercise of empathy, to put yourself in others' shoes,
|
50 |
+
as if Climate Change came crushing on your doorstep.
|
51 |
+
</p>
|
52 |
+
<br>
|
53 |
+
<p>
|
54 |
+
After you have selected an image and started the inference you
|
55 |
+
will see all the outputs of ClimateGAN, including intermediate
|
56 |
+
outputs such as the flood mask, the segmentation map and the
|
57 |
+
depth maps used to produce the 3 events.
|
58 |
+
</p>
|
59 |
+
<br>
|
60 |
+
<p>
|
61 |
+
This Space makes use of recent Stable Diffusion in-painting
|
62 |
+
pipelines to replace ClimateGAN's original Painter. If you
|
63 |
+
select 'Both' painters, you will see a comparison
|
64 |
+
</p>
|
65 |
+
<br>
|
66 |
+
<p style='text-align: center'>
|
67 |
+
Visit
|
68 |
+
<a href='https://thisclimatedoesnotexist.com/'>
|
69 |
+
ThisClimateDoesNotExist.com</a>
|
70 |
+
for more information
|
71 |
+
|
|
72 |
+
Original
|
73 |
+
<a href='https://github.com/cc-ai/climategan'>
|
74 |
+
ClimateGAN GitHub Repo
|
75 |
+
</a>
|
76 |
+
|
|
77 |
+
Read the original
|
78 |
+
<a
|
79 |
+
href='https://openreview.net/forum?id=EZNOb_uNpJk'
|
80 |
+
target='_blank'>
|
81 |
+
ICLR 2021 ClimateGAN paper
|
82 |
+
</a>
|
83 |
+
</p>
|
84 |
+
"""
|
85 |
+
),
|
86 |
+
dedent(
|
87 |
+
"""
|
88 |
+
## How to use this Space
|
89 |
+
|
90 |
+
1. Enter an address or upload a Street View image
|
91 |
+
2. Select the type of Painter you'd like to use for the flood renderings
|
92 |
+
3. Click on the "See for yourself!" button
|
93 |
+
4. Wait for the inference to complete, typically around 30 seconds
|
94 |
+
(plus queue time)
|
95 |
+
5. Enjoy the results!
|
96 |
+
|
97 |
+
1. The prompt for Stable Diffusion is `An HD picture of a street with
|
98 |
+
dirty water after a heavy flood`
|
99 |
+
2. Pay attention to potential "inventions" by Stable Diffusion's in-painting
|
100 |
+
3. The "restricted to masked area" SD output is the result of:
|
101 |
+
`y = mask * flooded + (1-mask) * input`
|
102 |
+
|
103 |
+
"""
|
104 |
+
),
|
105 |
+
]
|
106 |
+
CSS = dedent(
|
107 |
+
"""
|
108 |
+
a {
|
109 |
+
color: #0088ff;
|
110 |
+
text-decoration: underline;
|
111 |
+
}
|
112 |
+
strong {
|
113 |
+
color: #c34318;
|
114 |
+
font-weight: bolder;
|
115 |
+
}
|
116 |
+
#how-to-use-md li {
|
117 |
+
margin: 0.1em;
|
118 |
+
}
|
119 |
+
#how-to-use-md li p {
|
120 |
+
margin: 0.1em;
|
121 |
+
}
|
122 |
+
"""
|
123 |
+
)
|
124 |
+
|
125 |
|
126 |
+
def toggle(radio):
|
127 |
+
if "address" in radio.lower():
|
128 |
+
return [
|
129 |
+
gr.update(visible=True),
|
130 |
+
gr.update(visible=False),
|
131 |
+
gr.update(visible=True),
|
132 |
+
]
|
133 |
+
else:
|
134 |
+
return [
|
135 |
+
gr.update(visible=False),
|
136 |
+
gr.update(visible=True),
|
137 |
+
gr.update(visible=True),
|
138 |
+
]
|
139 |
|
140 |
+
|
141 |
+
def predict(cg: ClimateGAN, api_key):
|
142 |
def _predict(*args):
|
143 |
+
print(f"Starting inference ({str(datetime.now())})")
|
144 |
+
image = place = painter = radio = None
|
145 |
+
if api_key:
|
146 |
+
radio, image, place, painter = args
|
147 |
else:
|
148 |
+
image, painter = args
|
|
|
149 |
|
150 |
+
if api_key and place and "address" in radio.lower():
|
151 |
geocode_result = gmaps.geocode(place)
|
152 |
|
153 |
address = geocode_result[0]["formatted_address"]
|
154 |
static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
|
155 |
img_np = io.imread(static_map_url)
|
156 |
+
print("Using GSV image")
|
157 |
else:
|
158 |
+
print("Using user image")
|
159 |
img_np = image
|
160 |
+
|
161 |
+
painters = {
|
162 |
+
"ClimateGAN Painter": "climategan",
|
163 |
+
"Stable Diffusion Painter": "stable_diffusion",
|
164 |
+
"Both": "both",
|
165 |
+
}
|
166 |
+
print("Using painter", painters[painter])
|
167 |
+
output_dict = cg.infer_single(
|
168 |
+
img_np,
|
169 |
+
painters[painter],
|
170 |
+
concats=[
|
171 |
+
"input",
|
172 |
+
"masked_input",
|
173 |
+
"climategan_flood",
|
174 |
+
"stable_copy_flood",
|
175 |
+
],
|
176 |
+
as_pil_image=True,
|
177 |
+
)
|
178 |
+
|
179 |
+
input_image = output_dict["input"]
|
180 |
+
masked_input = output_dict["masked_input"]
|
181 |
+
wildfire = output_dict["wildfire"]
|
182 |
+
smog = output_dict["smog"]
|
183 |
+
depth = np.repeat(output_dict["depth"], 3, axis=-1)
|
184 |
+
segmentation = output_dict["segmentation"]
|
185 |
+
|
186 |
+
climategan_flood = output_dict.get(
|
187 |
+
"climategan_flood",
|
188 |
+
np.ones(input_image.shape, dtype=np.uint8) * 255,
|
189 |
+
)
|
190 |
+
stable_flood = output_dict.get(
|
191 |
+
"stable_flood",
|
192 |
+
np.ones(input_image.shape, dtype=np.uint8) * 255,
|
193 |
+
)
|
194 |
+
stable_copy_flood = output_dict.get(
|
195 |
+
"stable_copy_flood",
|
196 |
+
np.ones(input_image.shape, dtype=np.uint8) * 255,
|
197 |
+
)
|
198 |
+
concat = output_dict.get(
|
199 |
+
"concat",
|
200 |
+
np.ones(input_image.shape, dtype=np.uint8) * 255,
|
201 |
+
)
|
202 |
+
|
203 |
+
return (
|
204 |
+
input_image,
|
205 |
+
masked_input,
|
206 |
+
segmentation,
|
207 |
+
depth,
|
208 |
+
climategan_flood,
|
209 |
+
stable_flood,
|
210 |
+
stable_copy_flood,
|
211 |
+
concat,
|
212 |
+
wildfire,
|
213 |
+
smog,
|
214 |
+
)
|
215 |
|
216 |
return _predict
|
217 |
|
218 |
|
219 |
if __name__ == "__main__":
|
220 |
|
221 |
+
ip = get("https://api.ipify.org").content.decode("utf8")
|
222 |
+
print("My public IP address is: {}".format(ip))
|
223 |
+
|
224 |
api_key = os.environ.get("GMAPS_API_KEY")
|
225 |
gmaps = None
|
226 |
if api_key is not None:
|
227 |
gmaps = googlemaps.Client(key=api_key)
|
228 |
|
229 |
+
cg = ClimateGAN(
|
230 |
+
model_path="config/model/masker",
|
231 |
+
dev_mode=os.environ.get("CG_DEV_MODE", "").lower() == "true",
|
232 |
+
)
|
233 |
+
cg._setup_stable_diffusion()
|
234 |
+
|
235 |
+
radio = address = None
|
236 |
+
pred_ins = []
|
237 |
+
pred_outs = []
|
238 |
+
|
239 |
+
with gr.Blocks(css=CSS) as app:
|
240 |
+
with Row():
|
241 |
+
with Column():
|
242 |
+
Markdown("# ClimateGAN: Visualize Climate Change")
|
243 |
+
HTML(TEXTS[0])
|
244 |
+
with Column():
|
245 |
+
Markdown(TEXTS[1], elem_id="how-to-use-md")
|
246 |
+
with Row():
|
247 |
+
HTML("<hr><br><h2 style='font-size: 1.5rem;'>Choose Inputs</h2>")
|
248 |
+
with Row():
|
249 |
+
with Column():
|
250 |
+
if api_key:
|
251 |
+
radio = Radio(["From Address", "From Image"], label="Input Type")
|
252 |
+
pred_ins += [radio]
|
253 |
+
im_inp = Image(label="Input Image", visible=not api_key)
|
254 |
+
pred_ins += [im_inp]
|
255 |
+
if api_key:
|
256 |
+
address = Textbox(label="Address or place name", visible=False)
|
257 |
+
pred_ins += [address]
|
258 |
+
with Column():
|
259 |
+
pred_ins += [
|
260 |
+
Dropdown(
|
261 |
+
choices=[
|
262 |
+
"ClimateGAN Painter",
|
263 |
+
"Stable Diffusion Painter",
|
264 |
+
"Both",
|
265 |
+
],
|
266 |
+
label="Choose Flood Painter",
|
267 |
+
value="Both",
|
268 |
+
)
|
269 |
+
]
|
270 |
+
btn = Button(
|
271 |
+
"See for yourself!",
|
272 |
+
label="Run",
|
273 |
+
variant="primary",
|
274 |
+
visible=not api_key,
|
275 |
+
)
|
276 |
+
with Row():
|
277 |
+
Markdown("## Outputs")
|
278 |
+
with Row():
|
279 |
+
pred_outs += [Image(type="numpy", label="Original image")]
|
280 |
+
pred_outs += [Image(type="numpy", label="Masked input image")]
|
281 |
+
pred_outs += [Image(type="numpy", label="Segmentation map")]
|
282 |
+
pred_outs += [Image(type="numpy", label="Depth map")]
|
283 |
+
with Row():
|
284 |
+
pred_outs += [Image(type="numpy", label="ClimateGAN-Flooded image")]
|
285 |
+
pred_outs += [Image(type="numpy", label="Stable Diffusion-Flooded image")]
|
286 |
+
pred_outs += [
|
287 |
+
Image(
|
288 |
+
type="numpy",
|
289 |
+
label="Stable Diffusion-Flooded image (restricted to masked area)",
|
290 |
+
)
|
291 |
+
]
|
292 |
+
with Row():
|
293 |
+
pred_outs += [Image(type="numpy", label="Comparison of flood images")]
|
294 |
+
with Row():
|
295 |
+
pred_outs += [Image(type="numpy", label="Wildfire")]
|
296 |
+
pred_outs += [Image(type="numpy", label="Smog")]
|
297 |
+
Image(type="numpy", label="Empty on purpose", interactive=False)
|
298 |
+
btn.click(predict(cg, api_key), inputs=pred_ins, outputs=pred_outs)
|
299 |
+
|
300 |
+
if api_key:
|
301 |
+
radio.change(toggle, inputs=[radio], outputs=[address, im_inp, btn])
|
302 |
+
|
303 |
+
app.launch(show_api=False)
|
climategan/trainer.py
CHANGED
@@ -22,7 +22,8 @@ from torch import autograd, sigmoid, softmax
|
|
22 |
from torch.cuda.amp import GradScaler, autocast
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
-
from climategan.data import get_all_loaders
|
|
|
26 |
from climategan.discriminator import OmniDiscriminator, create_discriminator
|
27 |
from climategan.eval_metrics import accuracy, mIOU
|
28 |
from climategan.fid import compute_val_fid
|
@@ -41,6 +42,7 @@ from climategan.tutils import (
|
|
41 |
print_num_parameters,
|
42 |
shuffle_batch_tuple,
|
43 |
srgb2lrgb,
|
|
|
44 |
vgg_preprocess,
|
45 |
zero_grad,
|
46 |
)
|
@@ -223,18 +225,21 @@ class Trainer:
|
|
223 |
bin_value=-1,
|
224 |
half=False,
|
225 |
xla=False,
|
226 |
-
cloudy=
|
227 |
auto_resize_640=False,
|
228 |
ignore_event=set(),
|
229 |
-
|
230 |
):
|
231 |
"""
|
232 |
-
Create a
|
233 |
single or batch image data.
|
234 |
|
235 |
-
stores is a
|
236 |
|
237 |
bin_value is used to binarize (or not) flood masks
|
|
|
|
|
|
|
238 |
"""
|
239 |
assert self.is_setup
|
240 |
assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
|
@@ -308,28 +313,39 @@ class Trainer:
|
|
308 |
if xla:
|
309 |
xm.mark_step()
|
310 |
|
|
|
|
|
311 |
if numpy:
|
312 |
with Timer(store=stores.get("numpy", [])):
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
|
334 |
return output_data
|
335 |
|
|
|
22 |
from torch.cuda.amp import GradScaler, autocast
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
+
from climategan.data import get_all_loaders, decode_segmap_merged_labels
|
26 |
+
|
27 |
from climategan.discriminator import OmniDiscriminator, create_discriminator
|
28 |
from climategan.eval_metrics import accuracy, mIOU
|
29 |
from climategan.fid import compute_val_fid
|
|
|
42 |
print_num_parameters,
|
43 |
shuffle_batch_tuple,
|
44 |
srgb2lrgb,
|
45 |
+
tensor_to_uint8_numpy_image,
|
46 |
vgg_preprocess,
|
47 |
zero_grad,
|
48 |
)
|
|
|
225 |
bin_value=-1,
|
226 |
half=False,
|
227 |
xla=False,
|
228 |
+
cloudy=True,
|
229 |
auto_resize_640=False,
|
230 |
ignore_event=set(),
|
231 |
+
return_intermediates=False,
|
232 |
):
|
233 |
"""
|
234 |
+
Create a dictionary of events from a numpy or tensor,
|
235 |
single or batch image data.
|
236 |
|
237 |
+
stores is a dictionary of times for the Timer class.
|
238 |
|
239 |
bin_value is used to binarize (or not) flood masks
|
240 |
+
|
241 |
+
all values in the output dictionary have 4 dimensions:
|
242 |
+
BxHxWxC if numpy else BxCxHxW
|
243 |
"""
|
244 |
assert self.is_setup
|
245 |
assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
|
|
|
313 |
if xla:
|
314 |
xm.mark_step()
|
315 |
|
316 |
+
output_data = {}
|
317 |
+
|
318 |
if numpy:
|
319 |
with Timer(store=stores.get("numpy", [])):
|
320 |
+
if "flood" not in ignore_event:
|
321 |
+
# normalize to 0-1
|
322 |
+
flood = tensor_to_uint8_numpy_image(flood)
|
323 |
+
# convert to 0-255 uint8
|
324 |
+
output_data["flood"] = flood
|
325 |
+
if "wildfire" not in ignore_event:
|
326 |
+
wildfire = tensor_to_uint8_numpy_image(wildfire)
|
327 |
+
output_data["wildfire"] = wildfire
|
328 |
+
if "smog" not in ignore_event:
|
329 |
+
smog = tensor_to_uint8_numpy_image(smog)
|
330 |
+
output_data["smog"] = smog
|
331 |
+
|
332 |
+
if return_intermediates:
|
333 |
+
if numpy:
|
334 |
+
output_data["mask"] = (
|
335 |
+
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
336 |
+
)
|
337 |
+
output_data["depth"] = tensor_to_uint8_numpy_image(depth)
|
338 |
+
output_data["segmentation"] = (
|
339 |
+
decode_segmap_merged_labels(segmentation, "r", False)
|
340 |
+
.cpu()
|
341 |
+
.permute(0, 2, 3, 1)
|
342 |
+
.numpy()
|
343 |
+
.astype(np.uint8)
|
344 |
+
)
|
345 |
+
else:
|
346 |
+
output_data["mask"] = mask
|
347 |
+
output_data["depth"] = depth
|
348 |
+
output_data["segmentation"] = segmentation
|
349 |
|
350 |
return output_data
|
351 |
|
climategan/tutils.py
CHANGED
@@ -564,14 +564,29 @@ def lrgb2srgb(ims):
|
|
564 |
return outs[0]
|
565 |
|
566 |
|
567 |
-
def normalize(t, mini=0, maxi=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
568 |
if len(t.shape) == 3:
|
569 |
return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
|
570 |
|
571 |
batch_size = t.shape[0]
|
572 |
-
|
|
|
573 |
t = t - min_t
|
574 |
-
max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size,
|
575 |
t = t / max_t
|
576 |
return mini + (maxi - mini) * t
|
577 |
|
@@ -644,7 +659,7 @@ def write_architecture(trainer):
|
|
644 |
f.write(output)
|
645 |
|
646 |
|
647 |
-
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t
|
648 |
delta = (res[0] / shape[0], res[1] / shape[1])
|
649 |
d = (shape[0] // res[0], shape[1] // res[1])
|
650 |
|
@@ -719,3 +734,30 @@ def tensor_ims_to_np_uint8s(ims):
|
|
719 |
nps.append(n.astype(np.uint8))
|
720 |
|
721 |
return nps[0] if len(nps) == 1 else nps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
return outs[0]
|
565 |
|
566 |
|
567 |
+
def normalize(t, mini=0.0, maxi=1.0):
|
568 |
+
"""
|
569 |
+
Normalizes a tensor to [0, 1].
|
570 |
+
If the tensor has more than 3 dimensions, the first one
|
571 |
+
is assumed to be the batch dimension and the tensor is
|
572 |
+
normalized per batch element, not across the batches.
|
573 |
+
|
574 |
+
Args:
|
575 |
+
t (torch.Tensor): Tensor to normalize
|
576 |
+
mini (float, optional): Min allowed value. Defaults to 0.
|
577 |
+
maxi (float, optional): Max allowed value. Defaults to 1.
|
578 |
+
|
579 |
+
Returns:
|
580 |
+
torch.Tensor: The normalized tensor
|
581 |
+
"""
|
582 |
if len(t.shape) == 3:
|
583 |
return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
|
584 |
|
585 |
batch_size = t.shape[0]
|
586 |
+
extra_dims = [1] * (t.ndim - 1)
|
587 |
+
min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, *extra_dims)
|
588 |
t = t - min_t
|
589 |
+
max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, *extra_dims)
|
590 |
t = t / max_t
|
591 |
return mini + (maxi - mini) * t
|
592 |
|
|
|
659 |
f.write(output)
|
660 |
|
661 |
|
662 |
+
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
663 |
delta = (res[0] / shape[0], res[1] / shape[1])
|
664 |
d = (shape[0] // res[0], shape[1] // res[1])
|
665 |
|
|
|
734 |
nps.append(n.astype(np.uint8))
|
735 |
|
736 |
return nps[0] if len(nps) == 1 else nps
|
737 |
+
|
738 |
+
|
739 |
+
def tensor_to_uint8_numpy_image(tensor):
|
740 |
+
"""
|
741 |
+
Turns a BxCxHxW tensor into a numpy image:
|
742 |
+
* normalize
|
743 |
+
* to [0, 255]
|
744 |
+
* detach
|
745 |
+
* channels last
|
746 |
+
* to uin8
|
747 |
+
* to cpu
|
748 |
+
* to numpy
|
749 |
+
|
750 |
+
Args:
|
751 |
+
tensor (torch.Tensor): Tensor to transform
|
752 |
+
|
753 |
+
Returns:
|
754 |
+
np.array: BxHxWxC np.uint8 array in [0, 255]
|
755 |
+
"""
|
756 |
+
return (
|
757 |
+
normalize(tensor, 0, 255) # [0, 255]
|
758 |
+
.detach() # detach from graph if needed
|
759 |
+
.permute(0, 2, 3, 1) # BxHxWxC
|
760 |
+
.to(torch.uint8) # uint8
|
761 |
+
.cpu() # cpu
|
762 |
+
.numpy() # numpy array
|
763 |
+
)
|
climategan/utils.py
CHANGED
@@ -917,14 +917,14 @@ def all_texts_to_array(texts, width=640, height=40):
|
|
917 |
|
918 |
|
919 |
class Timer:
|
920 |
-
def __init__(self, name="", store=None, precision=3, ignore=False, cuda=
|
921 |
self.name = name
|
922 |
self.store = store
|
923 |
self.precision = precision
|
924 |
self.ignore = ignore
|
925 |
-
self.cuda = cuda
|
926 |
|
927 |
-
if cuda:
|
928 |
self._start_event = torch.cuda.Event(enable_timing=True)
|
929 |
self._end_event = torch.cuda.Event(enable_timing=True)
|
930 |
|
|
|
917 |
|
918 |
|
919 |
class Timer:
|
920 |
+
def __init__(self, name="", store=None, precision=3, ignore=False, cuda=None):
|
921 |
self.name = name
|
922 |
self.store = store
|
923 |
self.precision = precision
|
924 |
self.ignore = ignore
|
925 |
+
self.cuda = cuda if cuda is not None else torch.cuda.is_available()
|
926 |
|
927 |
+
if self.cuda:
|
928 |
self._start_event = torch.cuda.Event(enable_timing=True)
|
929 |
self._end_event = torch.cuda.Event(enable_timing=True)
|
930 |
|
climategan_wrapper.py
ADDED
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
|
2 |
+
# thank you @NimaBoscarino
|
3 |
+
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
from pathlib import Path
|
7 |
+
from uuid import uuid4
|
8 |
+
from minydra import resolved_args
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from diffusers import StableDiffusionInpaintPipeline
|
12 |
+
from PIL import Image
|
13 |
+
from skimage.color import rgba2rgb
|
14 |
+
from skimage.transform import resize
|
15 |
+
|
16 |
+
from climategan.trainer import Trainer
|
17 |
+
|
18 |
+
|
19 |
+
CUDA = torch.cuda.is_available()
|
20 |
+
|
21 |
+
|
22 |
+
def concat_events(output_dict, events, i=None, axis=1):
|
23 |
+
"""
|
24 |
+
Concatenates the `i`th data in `output_dict` according to the keys listed
|
25 |
+
in `events` on dimension `axis`.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
|
29 |
+
events to their corresponding data :
|
30 |
+
{k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
|
31 |
+
events (list[str]): output_dict's keys to concatenate.
|
32 |
+
axis (int, optional): Concatenation axis. Defaults to 1.
|
33 |
+
"""
|
34 |
+
cs = [e for e in events if e in output_dict]
|
35 |
+
if i is not None:
|
36 |
+
return uint8(np.concatenate([output_dict[c][i] for c in cs], axis=axis))
|
37 |
+
return uint8(np.concatenate([output_dict[c] for c in cs], axis=axis))
|
38 |
+
|
39 |
+
|
40 |
+
def clear(folder):
|
41 |
+
"""
|
42 |
+
Deletes all the images without the inference separator "---" in their name.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
folder (Union[str, Path]): The folder to clear.
|
46 |
+
"""
|
47 |
+
for i in list(Path(folder).iterdir()):
|
48 |
+
if i.is_file() and "---" in i.stem:
|
49 |
+
i.unlink()
|
50 |
+
|
51 |
+
|
52 |
+
def uint8(array, rescale=False):
|
53 |
+
"""
|
54 |
+
convert an array to np.uint8 (does not rescale or anything else than changing dtype)
|
55 |
+
Args:
|
56 |
+
array (np.array): array to modify
|
57 |
+
Returns:
|
58 |
+
np.array(np.uint8): converted array
|
59 |
+
"""
|
60 |
+
if rescale:
|
61 |
+
if array.min() < 0:
|
62 |
+
if array.min() >= -1 and array.max() <= 1:
|
63 |
+
array = (array + 1) / 2
|
64 |
+
else:
|
65 |
+
raise ValueError(
|
66 |
+
f"Data range mismatch for image: ({array.min()}, {array.max()})"
|
67 |
+
)
|
68 |
+
if array.max() <= 1:
|
69 |
+
array = array * 255
|
70 |
+
return array.astype(np.uint8)
|
71 |
+
|
72 |
+
|
73 |
+
def resize_and_crop(img, to=640):
|
74 |
+
"""
|
75 |
+
Resizes an image so that it keeps the aspect ratio and the smallest dimensions
|
76 |
+
is `to`, then crops this resized image in its center so that the output is `to x to`
|
77 |
+
without aspect ratio distortion
|
78 |
+
Args:
|
79 |
+
img (np.array): np.uint8 255 image
|
80 |
+
Returns:
|
81 |
+
np.array: [0, 1] np.float32 image
|
82 |
+
"""
|
83 |
+
# resize keeping aspect ratio: smallest dim is 640
|
84 |
+
h, w = img.shape[:2]
|
85 |
+
if h < w:
|
86 |
+
size = (to, int(to * w / h))
|
87 |
+
else:
|
88 |
+
size = (int(to * h / w), to)
|
89 |
+
|
90 |
+
r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
|
91 |
+
r_img = uint8(r_img)
|
92 |
+
|
93 |
+
# crop in the center
|
94 |
+
H, W = r_img.shape[:2]
|
95 |
+
|
96 |
+
top = (H - to) // 2
|
97 |
+
left = (W - to) // 2
|
98 |
+
|
99 |
+
rc_img = r_img[top : top + to, left : left + to, :]
|
100 |
+
|
101 |
+
return rc_img / 255.0
|
102 |
+
|
103 |
+
|
104 |
+
def to_m1_p1(img):
|
105 |
+
"""
|
106 |
+
rescales a [0, 1] image to [-1, +1]
|
107 |
+
Args:
|
108 |
+
img (np.array): float32 numpy array of an image in [0, 1]
|
109 |
+
i (int): Index of the image being rescaled
|
110 |
+
Raises:
|
111 |
+
ValueError: If the image is not in [0, 1]
|
112 |
+
Returns:
|
113 |
+
np.array(np.float32): array in [-1, +1]
|
114 |
+
"""
|
115 |
+
if img.min() >= 0 and img.max() <= 1:
|
116 |
+
return (img.astype(np.float32) - 0.5) * 2
|
117 |
+
raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
|
118 |
+
|
119 |
+
|
120 |
+
# No need to do any timing in this, since it's just for the HF Space
|
121 |
+
class ClimateGAN:
|
122 |
+
def __init__(self, model_path, dev_mode=False) -> None:
|
123 |
+
"""
|
124 |
+
A wrapper for the ClimateGAN model that you can use to generate
|
125 |
+
events from images or folders containing images.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
model_path (Union[str, Path]): Where to load the Masker from
|
129 |
+
"""
|
130 |
+
torch.set_grad_enabled(False)
|
131 |
+
self.target_size = 640
|
132 |
+
self._stable_diffusion_is_setup = False
|
133 |
+
self.dev_mode = dev_mode
|
134 |
+
if self.dev_mode:
|
135 |
+
return
|
136 |
+
self.trainer = Trainer.resume_from_path(
|
137 |
+
model_path,
|
138 |
+
setup=True,
|
139 |
+
inference=True,
|
140 |
+
new_exp=None,
|
141 |
+
)
|
142 |
+
if CUDA:
|
143 |
+
self.trainer.G.half()
|
144 |
+
|
145 |
+
def _setup_stable_diffusion(self):
|
146 |
+
"""
|
147 |
+
Sets up the stable diffusion pipeline for in-painting.
|
148 |
+
Make sure you have accepted the license on the model's card
|
149 |
+
https://huggingface.co/CompVis/stable-diffusion-v1-4
|
150 |
+
"""
|
151 |
+
if self.dev_mode:
|
152 |
+
return
|
153 |
+
|
154 |
+
try:
|
155 |
+
self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
156 |
+
"runwayml/stable-diffusion-inpainting",
|
157 |
+
revision="fp16" if CUDA else "main",
|
158 |
+
torch_dtype=torch.float16 if CUDA else torch.float32,
|
159 |
+
safety_checker=None,
|
160 |
+
use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
|
161 |
+
).to(self.trainer.device)
|
162 |
+
self._stable_diffusion_is_setup = True
|
163 |
+
except Exception as e:
|
164 |
+
print(
|
165 |
+
"\nCould not load stable diffusion model. "
|
166 |
+
+ "Please make sure you have accepted the license on the model's"
|
167 |
+
+ " card https://huggingface.co/CompVis/stable-diffusion-v1-4\n"
|
168 |
+
)
|
169 |
+
raise e
|
170 |
+
|
171 |
+
def _preprocess_image(self, img):
|
172 |
+
"""
|
173 |
+
Turns a HxWxC uint8 numpy array into a 640x640x3 float32 numpy array
|
174 |
+
in [-1, 1].
|
175 |
+
|
176 |
+
Args:
|
177 |
+
img (np.array): Image to resize crop and rescale
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
np.array: Resized, cropped and rescaled image
|
181 |
+
"""
|
182 |
+
# rgba to rgb
|
183 |
+
data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
|
184 |
+
|
185 |
+
# to args.target_size
|
186 |
+
data = resize_and_crop(data, self.target_size)
|
187 |
+
|
188 |
+
# resize() produces [0, 1] images, rescale to [-1, 1]
|
189 |
+
data = to_m1_p1(data)
|
190 |
+
return data
|
191 |
+
|
192 |
+
# Does all three inferences at the moment.
|
193 |
+
def infer_single(
|
194 |
+
self,
|
195 |
+
orig_image,
|
196 |
+
painter="both",
|
197 |
+
prompt="An HD picture of a street with dirty water after a heavy flood",
|
198 |
+
concats=[
|
199 |
+
"input",
|
200 |
+
"masked_input",
|
201 |
+
"climategan_flood",
|
202 |
+
"stable_flood",
|
203 |
+
"stable_copy_flood",
|
204 |
+
],
|
205 |
+
as_pil_image=False,
|
206 |
+
):
|
207 |
+
"""
|
208 |
+
Infers the image with the ClimateGAN model.
|
209 |
+
Importantly (and unlike self.infer_preprocessed_batch), the image is
|
210 |
+
pre-processed by self._preprocess_image before going through the networks.
|
211 |
+
|
212 |
+
Output dict contains the following keys:
|
213 |
+
- "input": The input image
|
214 |
+
- "mask": The mask used to generate the flood (from ClimateGAN's Masker)
|
215 |
+
- "masked_input": The input image with the mask applied
|
216 |
+
- "climategan_flood": The flooded image generated by ClimateGAN's Painter
|
217 |
+
on the masked input (only if "painter" is "climategan" or "both").
|
218 |
+
- "stable_flood": The flooded image in-painted by the stable diffusion model
|
219 |
+
from the mask and the input image (only if "painter" is "stable_diffusion"
|
220 |
+
or "both").
|
221 |
+
- "stable_copy_flood": The flooded image in-painted by the stable diffusion
|
222 |
+
model with its original context pasted back in:
|
223 |
+
y = m * flooded + (1-m) * input
|
224 |
+
(only if "painter" is "stable_diffusion" or "both").
|
225 |
+
|
226 |
+
Args:
|
227 |
+
orig_image (Union[str, np.array]): image to infer on. Can be a path to
|
228 |
+
an image which will be read.
|
229 |
+
painter (str, optional): Which painter to use: "climategan",
|
230 |
+
"stable_diffusion" or "both". Defaults to "both".
|
231 |
+
prompt (str, optional): The prompt used to guide the diffusion. Defaults
|
232 |
+
to "An HD picture of a street with dirty water after a heavy flood".
|
233 |
+
concats (list, optional): List of keys in `output` to concatenate together
|
234 |
+
in a new `{original_stem}_concat` image written. Defaults to:
|
235 |
+
["input", "masked_input", "climategan_flood", "stable_flood",
|
236 |
+
"stable_copy_flood"].
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
dict: a dictionary containing the output images {k: HxWxC}. C is omitted
|
240 |
+
for masks (HxW).
|
241 |
+
"""
|
242 |
+
if self.dev_mode:
|
243 |
+
return {
|
244 |
+
"input": orig_image,
|
245 |
+
"mask": np.random.randint(0, 255, (640, 640)),
|
246 |
+
"masked_input": np.random.randint(0, 255, (640, 640, 3)),
|
247 |
+
"climategan_flood": np.random.randint(0, 255, (640, 640, 3)),
|
248 |
+
"stable_flood": np.random.randint(0, 255, (640, 640, 3)),
|
249 |
+
"stable_copy_flood": np.random.randint(0, 255, (640, 640, 3)),
|
250 |
+
"concat": np.random.randint(0, 255, (640, 640 * 5, 3)),
|
251 |
+
"smog": np.random.randint(0, 255, (640, 640, 3)),
|
252 |
+
"wildfire": np.random.randint(0, 255, (640, 640, 3)),
|
253 |
+
"depth": np.random.randint(0, 255, (640, 640, 1)),
|
254 |
+
"segmentation": np.random.randint(0, 255, (640, 640, 3)),
|
255 |
+
}
|
256 |
+
return
|
257 |
+
|
258 |
+
image_array = (
|
259 |
+
np.array(Image.open(orig_image))
|
260 |
+
if isinstance(orig_image, str)
|
261 |
+
else orig_image
|
262 |
+
)
|
263 |
+
|
264 |
+
pil_image = None
|
265 |
+
if as_pil_image:
|
266 |
+
pil_image = Image.fromarray(image_array)
|
267 |
+
print("Preprocessing image")
|
268 |
+
image = self._preprocess_image(image_array)
|
269 |
+
output_dict = self.infer_preprocessed_batch(
|
270 |
+
images=image[None, ...],
|
271 |
+
painter=painter,
|
272 |
+
prompt=prompt,
|
273 |
+
concats=concats,
|
274 |
+
pil_image=pil_image,
|
275 |
+
)
|
276 |
+
print("Inference done")
|
277 |
+
return {k: v[0] for k, v in output_dict.items()}
|
278 |
+
|
279 |
+
def infer_preprocessed_batch(
|
280 |
+
self,
|
281 |
+
images,
|
282 |
+
painter="both",
|
283 |
+
prompt="An HD picture of a street with dirty water after a heavy flood",
|
284 |
+
concats=[
|
285 |
+
"input",
|
286 |
+
"masked_input",
|
287 |
+
"climategan_flood",
|
288 |
+
"stable_flood",
|
289 |
+
"stable_copy_flood",
|
290 |
+
],
|
291 |
+
pil_image=None,
|
292 |
+
):
|
293 |
+
"""
|
294 |
+
Infers ClimateGAN predictions on a batch of preprocessed images.
|
295 |
+
It assumes that each image in the batch has been preprocessed with
|
296 |
+
self._preprocess_image().
|
297 |
+
|
298 |
+
Output dict contains the following keys:
|
299 |
+
- "input": The input image
|
300 |
+
- "mask": The mask used to generate the flood (from ClimateGAN's Masker)
|
301 |
+
- "masked_input": The input image with the mask applied
|
302 |
+
- "climategan_flood": The flooded image generated by ClimateGAN's Painter
|
303 |
+
on the masked input (only if "painter" is "climategan" or "both").
|
304 |
+
- "stable_flood": The flooded image in-painted by the stable diffusion model
|
305 |
+
from the mask and the input image (only if "painter" is "stable_diffusion"
|
306 |
+
or "both").
|
307 |
+
- "stable_copy_flood": The flooded image in-painted by the stable diffusion
|
308 |
+
model with its original context pasted back in:
|
309 |
+
y = m * flooded + (1-m) * input
|
310 |
+
(only if "painter" is "stable_diffusion" or "both").
|
311 |
+
|
312 |
+
Args:
|
313 |
+
images (np.array): A batch of input images BxHxWx3
|
314 |
+
painter (str, optional): Which painter to use: "climategan",
|
315 |
+
"stable_diffusion" or "both". Defaults to "both".
|
316 |
+
prompt (str, optional): The prompt used to guide the diffusion. Defaults
|
317 |
+
to "An HD picture of a street with dirty water after a heavy flood".
|
318 |
+
concats (list, optional): List of keys in `output` to concatenate together
|
319 |
+
in a new `{original_stem}_concat` image written. Defaults to:
|
320 |
+
["input", "masked_input", "climategan_flood", "stable_flood",
|
321 |
+
"stable_copy_flood"].
|
322 |
+
pil_image (PIL.Image, optional): The original PIL image. If provided,
|
323 |
+
will be used for a single inference (batch_size=1)
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
dict: a dictionary containing the output images
|
327 |
+
"""
|
328 |
+
assert painter in [
|
329 |
+
"both",
|
330 |
+
"stable_diffusion",
|
331 |
+
"climategan",
|
332 |
+
], f"Unknown painter: {painter}"
|
333 |
+
|
334 |
+
ignore_event = set()
|
335 |
+
if painter == "stable_diffusion":
|
336 |
+
ignore_event.add("flood")
|
337 |
+
|
338 |
+
if pil_image is not None:
|
339 |
+
print("Warning: `pil_image` has been provided, it will override `images`")
|
340 |
+
images = self._preprocess_image(np.array(pil_image))[None, ...]
|
341 |
+
pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))
|
342 |
+
|
343 |
+
# Retrieve numpy events as a dict {event: array[BxHxWxC]}
|
344 |
+
print("Inferring ClimateGAN events")
|
345 |
+
outputs = self.trainer.infer_all(
|
346 |
+
images,
|
347 |
+
numpy=True,
|
348 |
+
bin_value=0.5,
|
349 |
+
half=CUDA,
|
350 |
+
ignore_event=ignore_event,
|
351 |
+
return_intermediates=True,
|
352 |
+
)
|
353 |
+
|
354 |
+
outputs["input"] = uint8(images, True)
|
355 |
+
# from Bx1xHxW to BxHxWx1
|
356 |
+
outputs["masked_input"] = outputs["input"] * (
|
357 |
+
outputs["mask"].squeeze(1)[..., None] == 0
|
358 |
+
)
|
359 |
+
|
360 |
+
if painter in {"both", "climategan"}:
|
361 |
+
outputs["climategan_flood"] = outputs.pop("flood")
|
362 |
+
else:
|
363 |
+
del outputs["flood"]
|
364 |
+
|
365 |
+
if painter != "climategan":
|
366 |
+
if not self._stable_diffusion_is_setup:
|
367 |
+
print("Setting up stable diffusion in-painting pipeline")
|
368 |
+
self._setup_stable_diffusion()
|
369 |
+
|
370 |
+
mask = outputs["mask"].squeeze(1)
|
371 |
+
input_images = (
|
372 |
+
torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
|
373 |
+
if pil_image is None
|
374 |
+
else pil_image
|
375 |
+
)
|
376 |
+
input_mask = (
|
377 |
+
torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
|
378 |
+
if pil_image is None
|
379 |
+
else Image.fromarray(mask[0])
|
380 |
+
)
|
381 |
+
print("Inferring stable diffusion in-painting for 50 steps")
|
382 |
+
floods = self.sdip_pipeline(
|
383 |
+
prompt=[prompt] * images.shape[0],
|
384 |
+
image=input_images,
|
385 |
+
mask_image=input_mask,
|
386 |
+
height=640,
|
387 |
+
width=640,
|
388 |
+
num_inference_steps=50,
|
389 |
+
)
|
390 |
+
print("Stable diffusion in-painting done")
|
391 |
+
|
392 |
+
bin_mask = mask[..., None] > 0
|
393 |
+
flood = np.stack([np.array(i) for i in floods.images])
|
394 |
+
copy_flood = flood * bin_mask + uint8(images, True) * (1 - bin_mask)
|
395 |
+
outputs["stable_flood"] = flood
|
396 |
+
outputs["stable_copy_flood"] = copy_flood
|
397 |
+
|
398 |
+
if concats:
|
399 |
+
print("Concatenating flood images")
|
400 |
+
outputs["concat"] = concat_events(outputs, concats, axis=2)
|
401 |
+
|
402 |
+
return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}
|
403 |
+
|
404 |
+
def infer_folder(
|
405 |
+
self,
|
406 |
+
folder_path,
|
407 |
+
painter="both",
|
408 |
+
prompt="An HD picture of a street with dirty water after a heavy flood",
|
409 |
+
batch_size=4,
|
410 |
+
concats=[
|
411 |
+
"input",
|
412 |
+
"masked_input",
|
413 |
+
"climategan_flood",
|
414 |
+
"stable_flood",
|
415 |
+
"stable_copy_flood",
|
416 |
+
],
|
417 |
+
write=True,
|
418 |
+
overwrite=False,
|
419 |
+
):
|
420 |
+
"""
|
421 |
+
Infers the images in a folder with the ClimateGAN model, batching images for
|
422 |
+
inference according to the batch_size.
|
423 |
+
|
424 |
+
Images must end in .jpg, .jpeg or .png (not case-sensitive).
|
425 |
+
Images must not contain the separator ("---") in their name.
|
426 |
+
|
427 |
+
Images will be written to disk in the same folder as the input images, with
|
428 |
+
a name that depends on its data, potentially the prompt and a random
|
429 |
+
identifier in case multiple inferences are run in the folder.
|
430 |
+
|
431 |
+
Output dict contains the following keys:
|
432 |
+
- "input": The input image
|
433 |
+
- "mask": The mask used to generate the flood (from ClimateGAN's Masker)
|
434 |
+
- "masked_input": The input image with the mask applied
|
435 |
+
- "climategan_flood": The flooded image generated by ClimateGAN's Painter
|
436 |
+
on the masked input (only if "painter" is "climategan" or "both").
|
437 |
+
- "stable_flood": The flooded image in-painted by the stable diffusion model
|
438 |
+
from the mask and the input image (only if "painter" is "stable_diffusion"
|
439 |
+
or "both").
|
440 |
+
- "stable_copy_flood": The flooded image in-painted by the stable diffusion
|
441 |
+
model with its original context pasted back in:
|
442 |
+
y = m * flooded + (1-m) * input
|
443 |
+
(only if "painter" is "stable_diffusion" or "both").
|
444 |
+
|
445 |
+
Args:
|
446 |
+
folder_path (Union[str, Path]): Where to read images from.
|
447 |
+
painter (str, optional): Which painter to use: "climategan",
|
448 |
+
"stable_diffusion" or "both". Defaults to "both".
|
449 |
+
prompt (str, optional): The prompt used to guide the diffusion. Defaults
|
450 |
+
to "An HD picture of a street with dirty water after a heavy flood".
|
451 |
+
batch_size (int, optional): Size of inference batches. Defaults to 4.
|
452 |
+
concats (list, optional): List of keys in `output` to concatenate together
|
453 |
+
in a new `{original_stem}_concat` image written. Defaults to:
|
454 |
+
["input", "masked_input", "climategan_flood", "stable_flood",
|
455 |
+
"stable_copy_flood"].
|
456 |
+
write (bool, optional): Whether or not to write the outputs to the input
|
457 |
+
folder.Defaults to True.
|
458 |
+
overwrite (Union[bool, str], optional): Whether to overwrite the images or
|
459 |
+
not. If a string is provided, it will be included in the name.
|
460 |
+
Defaults to False.
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
dict: a dictionary containing the output images
|
464 |
+
"""
|
465 |
+
folder_path = Path(folder_path).expanduser().resolve()
|
466 |
+
assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
|
467 |
+
assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"
|
468 |
+
im_paths = [
|
469 |
+
p
|
470 |
+
for p in folder_path.iterdir()
|
471 |
+
if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
|
472 |
+
]
|
473 |
+
assert im_paths, f"No images found in {str(folder_path)}"
|
474 |
+
ims = [self._preprocess_image(np.array(Image.open(p))) for p in im_paths]
|
475 |
+
batches = [
|
476 |
+
np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
|
477 |
+
]
|
478 |
+
inferences = [
|
479 |
+
self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
|
480 |
+
]
|
481 |
+
|
482 |
+
outputs = {
|
483 |
+
k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
|
484 |
+
}
|
485 |
+
|
486 |
+
if write:
|
487 |
+
self.write(outputs, im_paths, painter, overwrite, prompt)
|
488 |
+
|
489 |
+
return outputs
|
490 |
+
|
491 |
+
def write(
|
492 |
+
self,
|
493 |
+
outputs,
|
494 |
+
im_paths,
|
495 |
+
painter="both",
|
496 |
+
overwrite=False,
|
497 |
+
prompt="",
|
498 |
+
):
|
499 |
+
"""
|
500 |
+
Writes the outputs of the inference to disk, in the input folder.
|
501 |
+
|
502 |
+
Images will be named like:
|
503 |
+
f"{original_stem}---{overwrite_prefix}_{painter_type}_{output_type}.{suffix}"
|
504 |
+
`painter_type` is either "climategan" or f"stable_diffusion_{prompt}"
|
505 |
+
|
506 |
+
Args:
|
507 |
+
outputs (_type_): The inference procedure's output dict.
|
508 |
+
im_paths (list[Path]): The list of input images paths.
|
509 |
+
painter (str, optional): Which painter was used. Defaults to "both".
|
510 |
+
overwrite (bool, optional): Whether to overwrite the images or not.
|
511 |
+
If a string is provided, it will be included in the name.
|
512 |
+
If False, a random identifier will be added to the name.
|
513 |
+
Defaults to False.
|
514 |
+
prompt (str, optional): The prompt used to guide the diffusion. Defaults
|
515 |
+
to "".
|
516 |
+
"""
|
517 |
+
prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()
|
518 |
+
overwrite_prefix = ""
|
519 |
+
if not overwrite:
|
520 |
+
overwrite_prefix = str(uuid4())[:8]
|
521 |
+
print("Writing events with prefix", overwrite_prefix)
|
522 |
+
else:
|
523 |
+
if isinstance(overwrite, str):
|
524 |
+
overwrite_prefix = overwrite
|
525 |
+
print("Writing events with prefix", overwrite_prefix)
|
526 |
+
|
527 |
+
# for each image, for each event/data type
|
528 |
+
for i, im_path in enumerate(im_paths):
|
529 |
+
for event, ims in outputs.items():
|
530 |
+
painter_prefix = ""
|
531 |
+
if painter == "climategan" and event == "flood":
|
532 |
+
painter_prefix = "climategan"
|
533 |
+
elif (
|
534 |
+
painter in {"stable_diffusion", "both"} and event == "stable_flood"
|
535 |
+
):
|
536 |
+
painter_prefix = f"_stable_{prompt}"
|
537 |
+
elif painter == "both" and event == "climategan_flood":
|
538 |
+
painter_prefix = ""
|
539 |
+
|
540 |
+
im = ims[i]
|
541 |
+
im = Image.fromarray(uint8(im))
|
542 |
+
imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
|
543 |
+
im.save(im_path.parent / (imstem + im_path.suffix))
|
544 |
+
|
545 |
+
|
546 |
+
if __name__ == "__main__":
|
547 |
+
print("Run `$ python climategan_wrapper.py help` for usage instructions\n")
|
548 |
+
|
549 |
+
# parse arguments
|
550 |
+
args = resolved_args(
|
551 |
+
defaults={
|
552 |
+
"input_folder": None,
|
553 |
+
"output_folder": None,
|
554 |
+
"painter": "both",
|
555 |
+
"help": False,
|
556 |
+
}
|
557 |
+
)
|
558 |
+
|
559 |
+
# print help
|
560 |
+
if args.help:
|
561 |
+
print(
|
562 |
+
"Usage: python inference.py input_folder=/path/to/folder\n"
|
563 |
+
+ "By default inferences will be stored in the input folder.\n"
|
564 |
+
+ "Add `output_folder=/path/to/folder` for a different output folder.\n"
|
565 |
+
+ "By default, both ClimateGAN and Stable Diffusion will be used."
|
566 |
+
+ "Change this by adding `painter=climategan` or"
|
567 |
+
+ " `painter=stable_diffusion`.\n"
|
568 |
+
+ "Make sure you have agreed to the terms of use for the models."
|
569 |
+
+ "In particular, visit SD's model card to agree to the terms of use:"
|
570 |
+
+ " https://huggingface.co/runwayml/stable-diffusion-inpainting"
|
571 |
+
)
|
572 |
+
# print args
|
573 |
+
args.pretty_print()
|
574 |
+
|
575 |
+
# load models
|
576 |
+
cg = ClimateGAN("models/climategan")
|
577 |
+
|
578 |
+
# check painter type
|
579 |
+
assert args.painter in {"climategan", "stable_diffusion", "both",}, (
|
580 |
+
f"Unknown painter {args.painter}. "
|
581 |
+
+ "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
|
582 |
+
)
|
583 |
+
|
584 |
+
# load SD pipeline if need be
|
585 |
+
if args.painter != "climate_gan":
|
586 |
+
cg._setup_stable_diffusion()
|
587 |
+
|
588 |
+
# resolve input folder path
|
589 |
+
in_path = Path(args.input_folder).expanduser().resolve()
|
590 |
+
assert in_path.exists(), f"Folder {str(in_path)} does not exist"
|
591 |
+
|
592 |
+
# output is input if not specified
|
593 |
+
if args.output_folder is None:
|
594 |
+
out_path = in_path
|
595 |
+
|
596 |
+
# find images in input folder
|
597 |
+
im_paths = [
|
598 |
+
p
|
599 |
+
for p in in_path.iterdir()
|
600 |
+
if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
|
601 |
+
]
|
602 |
+
assert im_paths, f"No images found in {str(im_paths)}"
|
603 |
+
|
604 |
+
print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")
|
605 |
+
|
606 |
+
# infer and write
|
607 |
+
for i, im_path in enumerate(im_paths):
|
608 |
+
print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
|
609 |
+
outs = cg.infer_single(
|
610 |
+
np.array(Image.open(im_path)),
|
611 |
+
args.painter,
|
612 |
+
as_pil_image=True,
|
613 |
+
concats=[
|
614 |
+
"input",
|
615 |
+
"masked_input",
|
616 |
+
"climategan_flood",
|
617 |
+
"stable_copy_flood",
|
618 |
+
],
|
619 |
+
)
|
620 |
+
for k, v in outs.items():
|
621 |
+
name = f"{im_path.stem}---{k}{im_path.suffix}"
|
622 |
+
im = Image.fromarray(uint8(v))
|
623 |
+
im.save(out_path / name)
|
624 |
+
print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")
|
inferences.py
DELETED
@@ -1,108 +0,0 @@
|
|
1 |
-
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
|
2 |
-
# thank you @NimaBoscarino
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from skimage.color import rgba2rgb
|
6 |
-
from skimage.transform import resize
|
7 |
-
import numpy as np
|
8 |
-
|
9 |
-
from climategan.trainer import Trainer
|
10 |
-
|
11 |
-
|
12 |
-
def uint8(array):
|
13 |
-
"""
|
14 |
-
convert an array to np.uint8 (does not rescale or anything else than changing dtype)
|
15 |
-
Args:
|
16 |
-
array (np.array): array to modify
|
17 |
-
Returns:
|
18 |
-
np.array(np.uint8): converted array
|
19 |
-
"""
|
20 |
-
return array.astype(np.uint8)
|
21 |
-
|
22 |
-
|
23 |
-
def resize_and_crop(img, to=640):
|
24 |
-
"""
|
25 |
-
Resizes an image so that it keeps the aspect ratio and the smallest dimensions
|
26 |
-
is `to`, then crops this resized image in its center so that the output is `to x to`
|
27 |
-
without aspect ratio distortion
|
28 |
-
Args:
|
29 |
-
img (np.array): np.uint8 255 image
|
30 |
-
Returns:
|
31 |
-
np.array: [0, 1] np.float32 image
|
32 |
-
"""
|
33 |
-
# resize keeping aspect ratio: smallest dim is 640
|
34 |
-
h, w = img.shape[:2]
|
35 |
-
if h < w:
|
36 |
-
size = (to, int(to * w / h))
|
37 |
-
else:
|
38 |
-
size = (int(to * h / w), to)
|
39 |
-
|
40 |
-
r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
|
41 |
-
r_img = uint8(r_img)
|
42 |
-
|
43 |
-
# crop in the center
|
44 |
-
H, W = r_img.shape[:2]
|
45 |
-
|
46 |
-
top = (H - to) // 2
|
47 |
-
left = (W - to) // 2
|
48 |
-
|
49 |
-
rc_img = r_img[top : top + to, left : left + to, :]
|
50 |
-
|
51 |
-
return rc_img / 255.0
|
52 |
-
|
53 |
-
|
54 |
-
def to_m1_p1(img):
|
55 |
-
"""
|
56 |
-
rescales a [0, 1] image to [-1, +1]
|
57 |
-
Args:
|
58 |
-
img (np.array): float32 numpy array of an image in [0, 1]
|
59 |
-
i (int): Index of the image being rescaled
|
60 |
-
Raises:
|
61 |
-
ValueError: If the image is not in [0, 1]
|
62 |
-
Returns:
|
63 |
-
np.array(np.float32): array in [-1, +1]
|
64 |
-
"""
|
65 |
-
if img.min() >= 0 and img.max() <= 1:
|
66 |
-
return (img.astype(np.float32) - 0.5) * 2
|
67 |
-
raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
|
68 |
-
|
69 |
-
|
70 |
-
# No need to do any timing in this, since it's just for the HF Space
|
71 |
-
class ClimateGAN:
|
72 |
-
def __init__(self, model_path) -> None:
|
73 |
-
torch.set_grad_enabled(False)
|
74 |
-
self.target_size = 640
|
75 |
-
self.trainer = Trainer.resume_from_path(
|
76 |
-
model_path,
|
77 |
-
setup=True,
|
78 |
-
inference=True,
|
79 |
-
new_exp=None,
|
80 |
-
)
|
81 |
-
|
82 |
-
# Does all three inferences at the moment.
|
83 |
-
def inference(self, orig_image):
|
84 |
-
image = self._preprocess_image(orig_image)
|
85 |
-
|
86 |
-
# Retrieve numpy events as a dict {event: array[BxHxWxC]}
|
87 |
-
outputs = self.trainer.infer_all(
|
88 |
-
image,
|
89 |
-
numpy=True,
|
90 |
-
bin_value=0.5,
|
91 |
-
)
|
92 |
-
|
93 |
-
return (
|
94 |
-
outputs["flood"].squeeze(),
|
95 |
-
outputs["wildfire"].squeeze(),
|
96 |
-
outputs["smog"].squeeze(),
|
97 |
-
)
|
98 |
-
|
99 |
-
def _preprocess_image(self, img):
|
100 |
-
# rgba to rgb
|
101 |
-
data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
|
102 |
-
|
103 |
-
# to args.target_size
|
104 |
-
data = resize_and_crop(data, self.target_size)
|
105 |
-
|
106 |
-
# resize() produces [0, 1] images, rescale to [-1, 1]
|
107 |
-
data = to_m1_p1(data)
|
108 |
-
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
addict==2.4.0
|
2 |
+
aiohttp==3.8.3
|
3 |
+
aiosignal==1.2.0
|
4 |
+
anyio==3.6.2
|
5 |
+
appnope==0.1.3
|
6 |
+
APScheduler==3.7.0
|
7 |
+
asttokens==2.0.8
|
8 |
+
async-timeout==4.0.2
|
9 |
+
attrs==21.2.0
|
10 |
+
backcall==0.2.0
|
11 |
+
bcrypt==4.0.1
|
12 |
+
black==22.10.0
|
13 |
+
blis==0.7.9
|
14 |
+
Brotli==1.0.9
|
15 |
+
catalogue==2.0.8
|
16 |
+
certifi==2021.5.30
|
17 |
+
cffi==1.15.1
|
18 |
+
charset-normalizer==2.0.4
|
19 |
+
click==8.0.1
|
20 |
+
codecarbon==1.2.0
|
21 |
+
comet-ml==3.15.3
|
22 |
+
commonmark==0.9.1
|
23 |
+
confection==0.0.3
|
24 |
+
configobj==5.0.6
|
25 |
+
contourpy==1.0.5
|
26 |
+
cryptography==38.0.1
|
27 |
+
cycler==0.10.0
|
28 |
+
cymem==2.0.7
|
29 |
+
dash==2.0.0
|
30 |
+
dash-bootstrap-components==0.13.0
|
31 |
+
dash-core-components==2.0.0
|
32 |
+
dash-html-components==2.0.0
|
33 |
+
dash-table==5.0.0
|
34 |
+
dataclasses==0.6
|
35 |
+
decorator==5.0.9
|
36 |
+
diffusers==0.6.0
|
37 |
+
dulwich==0.20.25
|
38 |
+
everett==2.0.1
|
39 |
+
executing==1.1.1
|
40 |
+
fastapi==0.85.1
|
41 |
+
ffmpy==0.3.0
|
42 |
+
filelock==3.0.12
|
43 |
+
fire==0.4.0
|
44 |
+
flake8==5.0.4
|
45 |
+
Flask==2.0.1
|
46 |
+
Flask-Compress==1.10.1
|
47 |
+
fonttools==4.38.0
|
48 |
+
frozenlist==1.3.1
|
49 |
+
fsspec==2022.10.0
|
50 |
+
future==0.18.2
|
51 |
+
gdown==3.13.0
|
52 |
+
googlemaps==4.6.0
|
53 |
+
gradio==3.6
|
54 |
+
h11==0.12.0
|
55 |
+
httpcore==0.15.0
|
56 |
+
httpx==0.23.0
|
57 |
+
huggingface-hub==0.10.1
|
58 |
+
hydra-core==0.11.3
|
59 |
+
idna==3.2
|
60 |
+
imageio==2.9.0
|
61 |
+
importlib-metadata==5.0.0
|
62 |
+
ipython==7.27.0
|
63 |
+
itsdangerous==2.0.1
|
64 |
+
jedi==0.18.0
|
65 |
+
Jinja2==3.0.1
|
66 |
+
joblib==1.0.1
|
67 |
+
jsonschema==3.2.0
|
68 |
+
kiwisolver==1.3.2
|
69 |
+
kornia==0.5.10
|
70 |
+
langcodes==3.3.0
|
71 |
+
linkify-it-py==1.0.3
|
72 |
+
markdown-it-py==2.1.0
|
73 |
+
MarkupSafe==2.0.1
|
74 |
+
matplotlib==3.4.3
|
75 |
+
matplotlib-inline==0.1.2
|
76 |
+
mccabe==0.7.0
|
77 |
+
mdit-py-plugins==0.3.1
|
78 |
+
mdurl==0.1.2
|
79 |
+
minydra==0.1.6
|
80 |
+
multidict==6.0.2
|
81 |
+
murmurhash==1.0.9
|
82 |
+
mypy-extensions==0.4.3
|
83 |
+
networkx==2.6.2
|
84 |
+
numpy==1.21.2
|
85 |
+
nvidia-ml-py3==7.352.0
|
86 |
+
omegaconf==1.4.1
|
87 |
+
opencv-python==4.5.3.56
|
88 |
+
orjson==3.8.0
|
89 |
+
packaging==21.0
|
90 |
+
pandas==1.3.2
|
91 |
+
paramiko==2.11.0
|
92 |
+
parso==0.8.2
|
93 |
+
pathspec==0.10.1
|
94 |
+
pathy==0.6.2
|
95 |
+
pexpect==4.8.0
|
96 |
+
pickleshare==0.7.5
|
97 |
+
Pillow==8.3.2
|
98 |
+
platformdirs==2.5.2
|
99 |
+
plotly==5.3.1
|
100 |
+
preshed==3.0.8
|
101 |
+
prompt-toolkit==3.0.20
|
102 |
+
ptyprocess==0.7.0
|
103 |
+
pure-eval==0.2.2
|
104 |
+
py-cpuinfo==8.0.0
|
105 |
+
pycodestyle==2.9.1
|
106 |
+
pycparser==2.21
|
107 |
+
pycryptodome==3.15.0
|
108 |
+
pydantic==1.10.2
|
109 |
+
pydub==0.25.1
|
110 |
+
pyflakes==2.5.0
|
111 |
+
Pygments==2.10.0
|
112 |
+
PyNaCl==1.5.0
|
113 |
+
pynvml==11.0.0
|
114 |
+
pyparsing==2.4.7
|
115 |
+
pyrsistent==0.18.0
|
116 |
+
PySocks==1.7.1
|
117 |
+
python-dateutil==2.8.2
|
118 |
+
python-multipart==0.0.5
|
119 |
+
pytorch-ranger==0.1.1
|
120 |
+
pytz==2021.1
|
121 |
+
PyWavelets==1.1.1
|
122 |
+
PyYAML==5.4.1
|
123 |
+
regex==2022.9.13
|
124 |
+
requests==2.26.0
|
125 |
+
requests-toolbelt==0.9.1
|
126 |
+
rfc3986==1.5.0
|
127 |
+
rich==12.6.0
|
128 |
+
scikit-image==0.18.3
|
129 |
+
scikit-learn==0.24.2
|
130 |
+
scipy==1.7.1
|
131 |
+
seaborn==0.11.2
|
132 |
+
semantic-version==2.8.5
|
133 |
+
six==1.16.0
|
134 |
+
smart-open==5.2.1
|
135 |
+
sniffio==1.3.0
|
136 |
+
spacy==3.4.2
|
137 |
+
spacy-legacy==3.0.10
|
138 |
+
spacy-loggers==1.0.3
|
139 |
+
srsly==2.4.5
|
140 |
+
stack-data==0.5.1
|
141 |
+
starlette==0.20.4
|
142 |
+
tenacity==8.0.1
|
143 |
+
termcolor==1.1.0
|
144 |
+
thinc==8.1.5
|
145 |
+
threadpoolctl==2.2.0
|
146 |
+
tifffile==2021.8.30
|
147 |
+
tokenizers==0.13.1
|
148 |
+
tomli==2.0.1
|
149 |
+
torch==1.7.1
|
150 |
+
torch-optimizer==0.1.0
|
151 |
+
torchvision==0.8.2
|
152 |
+
tqdm==4.62.2
|
153 |
+
traitlets==5.1.0
|
154 |
+
transformers==4.23.1
|
155 |
+
typer==0.4.2
|
156 |
+
typing_extensions==4.4.0
|
157 |
+
tzlocal==2.1
|
158 |
+
uc-micro-py==1.0.1
|
159 |
+
urllib3==1.26.6
|
160 |
+
uvicorn==0.19.0
|
161 |
+
wasabi==0.10.1
|
162 |
+
wcwidth==0.2.5
|
163 |
+
websocket-client==1.2.1
|
164 |
+
websockets==10.3
|
165 |
+
Werkzeug==2.0.1
|
166 |
+
wrapt==1.12.1
|
167 |
+
wurlitzer==3.0.2
|
168 |
+
yarl==1.8.1
|
169 |
+
zipp==3.10.0
|