|
{ |
|
"imports": [ |
|
"$import torch", |
|
"$from datetime import datetime", |
|
"$from pathlib import Path", |
|
"$from PIL import Image", |
|
"$from scripts.utils import visualize_2d_image" |
|
], |
|
"bundle_root": ".", |

"model_dir": "$@bundle_root + '/models'", |

"output_dir": "$@bundle_root + '/output'", |

"create_output_dir": "$Path(@output_dir).mkdir(parents=True, exist_ok=True)", |

"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", |

"output_postfix": "$datetime.now().strftime('sample_%Y%m%d_%H%M%S')", |
|
"spatial_dims": 2, |

"channel": 0, |

"image_channels": 1, |

"latent_channels": 1, |

"latent_shape": ["@latent_channels", 64, 64], |
|
"autoencoder_def": { |

"_target_": "generative.networks.nets.AutoencoderKL", |

"spatial_dims": "@spatial_dims", |

"in_channels": "@image_channels", |

"out_channels": "@image_channels", |

"latent_channels": "@latent_channels", |

"num_channels": [64, 128, 256], |

"num_res_blocks": 2, |

"attention_levels": [false, false, false], |

"with_encoder_nonlocal_attn": true, |

"with_decoder_nonlocal_attn": true, |

"norm_num_groups": 32, |

"norm_eps": 1e-06 |

}, |
|
"network_def": { |

"_target_": "generative.networks.nets.DiffusionModelUNet", |

"spatial_dims": "@spatial_dims", |

"in_channels": "@latent_channels", |

"out_channels": "@latent_channels", |

"num_res_blocks": 2, |

"num_channels": [32, 64, 128, 256], |

"attention_levels": [false, true, true, true], |

"num_head_channels": [0, 32, 32, 32] |

}, |
|
"load_autoencoder_path": "$@model_dir + '/model_autoencoder.pt'", |

"load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path, map_location=@device))", |

"autoencoder": "$@autoencoder_def.to(@device)", |

"load_diffusion_path": "$@model_dir + '/model.pt'", |

"load_diffusion": "$@network_def.load_state_dict(torch.load(@load_diffusion_path, map_location=@device))", |

"diffusion": "$@network_def.to(@device)", |
|
"noise_scheduler": { |

"_target_": "generative.networks.schedulers.DDIMScheduler", |

"_requires_": ["@load_diffusion", "@load_autoencoder"], |

"num_train_timesteps": 1000, |

"beta_schedule": "scaled_linear", |

"beta_start": 0.0015, |

"beta_end": 0.0195, |

"clip_sample": false |

}, |
|
"noise": "$torch.randn([1]+@latent_shape).to(@device)", |

"num_inference_steps": 50, |

"set_timesteps": "$@noise_scheduler.set_timesteps(num_inference_steps=@num_inference_steps)", |

"inferer": { |

"_target_": "scripts.ldm_sampler.LDMSampler", |

"_requires_": "@set_timesteps" |

}, |

"sample": "$@inferer.sampling_fn(@noise, @autoencoder, @diffusion, @noise_scheduler)", |

"generated_image": "$@sample", |

"generated_image_np": "$@generated_image[0,0].cpu().numpy().transpose(1, 0)[::-1, ::-1]", |

"img_pil": "$Image.fromarray(visualize_2d_image(@generated_image_np), 'RGB')", |
|
"run": ["$@create_output_dir", "$@img_pil.save(@output_dir+'/synimg_'+@output_postfix+'.png')"] |
|
} |
|
|