multimodalart HF staff commited on
Commit
27486b3
Β·
1 Parent(s): 2799450

Upload 37 files

Browse files
configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: sgm.models.autoencoder.AutoencodingEngine
4
+ params:
5
+ input_key: jpg
6
+ monitor: val/rec_loss
7
+
8
+ loss_config:
9
+ target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
10
+ params:
11
+ perceptual_weight: 0.25
12
+ disc_start: 20001
13
+ disc_weight: 0.5
14
+ learn_logvar: True
15
+
16
+ regularization_weights:
17
+ kl_loss: 1.0
18
+
19
+ regularizer_config:
20
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
21
+
22
+ encoder_config:
23
+ target: sgm.modules.diffusionmodules.model.Encoder
24
+ params:
25
+ attn_type: none
26
+ double_z: True
27
+ z_channels: 4
28
+ resolution: 256
29
+ in_channels: 3
30
+ out_ch: 3
31
+ ch: 128
32
+ ch_mult: [1, 2, 4]
33
+ num_res_blocks: 4
34
+ attn_resolutions: []
35
+ dropout: 0.0
36
+
37
+ decoder_config:
38
+ target: sgm.modules.diffusionmodules.model.Decoder
39
+ params: ${model.params.encoder_config.params}
40
+
41
+ data:
42
+ target: sgm.data.dataset.StableDataModuleFromConfig
43
+ params:
44
+ train:
45
+ datapipeline:
46
+ urls:
47
+ - DATA-PATH
48
+ pipeline_config:
49
+ shardshuffle: 10000
50
+ sample_shuffle: 10000
51
+
52
+ decoders:
53
+ - pil
54
+
55
+ postprocessors:
56
+ - target: sdata.mappers.TorchVisionImageTransforms
57
+ params:
58
+ key: jpg
59
+ transforms:
60
+ - target: torchvision.transforms.Resize
61
+ params:
62
+ size: 256
63
+ interpolation: 3
64
+ - target: torchvision.transforms.ToTensor
65
+ - target: sdata.mappers.Rescaler
66
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
67
+ params:
68
+ h_key: height
69
+ w_key: width
70
+
71
+ loader:
72
+ batch_size: 8
73
+ num_workers: 4
74
+
75
+
76
+ lightning:
77
+ strategy:
78
+ target: pytorch_lightning.strategies.DDPStrategy
79
+ params:
80
+ find_unused_parameters: True
81
+
82
+ modelcheckpoint:
83
+ params:
84
+ every_n_train_steps: 5000
85
+
86
+ callbacks:
87
+ metrics_over_trainsteps_checkpoint:
88
+ params:
89
+ every_n_train_steps: 50000
90
+
91
+ image_logger:
92
+ target: main.ImageLogger
93
+ params:
94
+ enable_autocast: False
95
+ batch_frequency: 1000
96
+ max_images: 8
97
+ increase_log_steps: True
98
+
99
+ trainer:
100
+ devices: 0,
101
+ limit_val_batches: 50
102
+ benchmark: True
103
+ accumulate_grad_batches: 1
104
+ val_check_interval: 10000
configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 4.5e-6
3
+ target: sgm.models.autoencoder.AutoencodingEngine
4
+ params:
5
+ input_key: jpg
6
+ monitor: val/loss/rec
7
+ disc_start_iter: 0
8
+
9
+ encoder_config:
10
+ target: sgm.modules.diffusionmodules.model.Encoder
11
+ params:
12
+ attn_type: vanilla-xformers
13
+ double_z: true
14
+ z_channels: 8
15
+ resolution: 256
16
+ in_channels: 3
17
+ out_ch: 3
18
+ ch: 128
19
+ ch_mult: [1, 2, 4, 4]
20
+ num_res_blocks: 2
21
+ attn_resolutions: []
22
+ dropout: 0.0
23
+
24
+ decoder_config:
25
+ target: sgm.modules.diffusionmodules.model.Decoder
26
+ params: ${model.params.encoder_config.params}
27
+
28
+ regularizer_config:
29
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
30
+
31
+ loss_config:
32
+ target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
33
+ params:
34
+ perceptual_weight: 0.25
35
+ disc_start: 20001
36
+ disc_weight: 0.5
37
+ learn_logvar: True
38
+
39
+ regularization_weights:
40
+ kl_loss: 1.0
41
+
42
+ data:
43
+ target: sgm.data.dataset.StableDataModuleFromConfig
44
+ params:
45
+ train:
46
+ datapipeline:
47
+ urls:
48
+ - DATA-PATH
49
+ pipeline_config:
50
+ shardshuffle: 10000
51
+ sample_shuffle: 10000
52
+
53
+ decoders:
54
+ - pil
55
+
56
+ postprocessors:
57
+ - target: sdata.mappers.TorchVisionImageTransforms
58
+ params:
59
+ key: jpg
60
+ transforms:
61
+ - target: torchvision.transforms.Resize
62
+ params:
63
+ size: 256
64
+ interpolation: 3
65
+ - target: torchvision.transforms.ToTensor
66
+ - target: sdata.mappers.Rescaler
67
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
68
+ params:
69
+ h_key: height
70
+ w_key: width
71
+
72
+ loader:
73
+ batch_size: 8
74
+ num_workers: 4
75
+
76
+
77
+ lightning:
78
+ strategy:
79
+ target: pytorch_lightning.strategies.DDPStrategy
80
+ params:
81
+ find_unused_parameters: True
82
+
83
+ modelcheckpoint:
84
+ params:
85
+ every_n_train_steps: 5000
86
+
87
+ callbacks:
88
+ metrics_over_trainsteps_checkpoint:
89
+ params:
90
+ every_n_train_steps: 50000
91
+
92
+ image_logger:
93
+ target: main.ImageLogger
94
+ params:
95
+ enable_autocast: False
96
+ batch_frequency: 1000
97
+ max_images: 8
98
+ increase_log_steps: True
99
+
100
+ trainer:
101
+ devices: 0,
102
+ limit_val_batches: 50
103
+ benchmark: True
104
+ accumulate_grad_batches: 1
105
+ val_check_interval: 10000
configs/example_training/imagenet-f8_cond.yaml ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - cls
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
+ params:
22
+ num_idx: 1000
23
+
24
+ scaling_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
+ discretization_config:
27
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
+
29
+ network_config:
30
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ use_checkpoint: True
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 256
36
+ attention_resolutions: [1, 2, 4]
37
+ num_res_blocks: 2
38
+ channel_mult: [1, 2, 4]
39
+ num_head_channels: 64
40
+ num_classes: sequential
41
+ adm_in_channels: 1024
42
+ transformer_depth: 1
43
+ context_dim: 1024
44
+ spatial_transformer_attn_type: softmax-xformers
45
+
46
+ conditioner_config:
47
+ target: sgm.modules.GeneralConditioner
48
+ params:
49
+ emb_models:
50
+ - is_trainable: True
51
+ input_key: cls
52
+ ucg_rate: 0.2
53
+ target: sgm.modules.encoders.modules.ClassEmbedder
54
+ params:
55
+ add_sequence_dim: True
56
+ embed_dim: 1024
57
+ n_classes: 1000
58
+
59
+ - is_trainable: False
60
+ ucg_rate: 0.2
61
+ input_key: original_size_as_tuple
62
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
+ params:
64
+ outdim: 256
65
+
66
+ - is_trainable: False
67
+ input_key: crop_coords_top_left
68
+ ucg_rate: 0.2
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256
72
+
73
+ first_stage_config:
74
+ target: sgm.models.autoencoder.AutoencoderKL
75
+ params:
76
+ ckpt_path: CKPT_PATH
77
+ embed_dim: 4
78
+ monitor: val/rec_loss
79
+ ddconfig:
80
+ attn_type: vanilla-xformers
81
+ double_z: true
82
+ z_channels: 4
83
+ resolution: 256
84
+ in_channels: 3
85
+ out_ch: 3
86
+ ch: 128
87
+ ch_mult: [1, 2, 4, 4]
88
+ num_res_blocks: 2
89
+ attn_resolutions: []
90
+ dropout: 0.0
91
+ lossconfig:
92
+ target: torch.nn.Identity
93
+
94
+ loss_fn_config:
95
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
+ params:
97
+ loss_weighting_config:
98
+ target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
99
+ sigma_sampler_config:
100
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
101
+ params:
102
+ num_idx: 1000
103
+
104
+ discretization_config:
105
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
106
+
107
+ sampler_config:
108
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
109
+ params:
110
+ num_steps: 50
111
+
112
+ discretization_config:
113
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
114
+
115
+ guider_config:
116
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
117
+ params:
118
+ scale: 5.0
119
+
120
+ data:
121
+ target: sgm.data.dataset.StableDataModuleFromConfig
122
+ params:
123
+ train:
124
+ datapipeline:
125
+ urls:
126
+ # USER: adapt this path the root of your custom dataset
127
+ - DATA_PATH
128
+ pipeline_config:
129
+ shardshuffle: 10000
130
+ sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
131
+
132
+ decoders:
133
+ - pil
134
+
135
+ postprocessors:
136
+ - target: sdata.mappers.TorchVisionImageTransforms
137
+ params:
138
+ key: jpg # USER: you might wanna adapt this for your custom dataset
139
+ transforms:
140
+ - target: torchvision.transforms.Resize
141
+ params:
142
+ size: 256
143
+ interpolation: 3
144
+ - target: torchvision.transforms.ToTensor
145
+ - target: sdata.mappers.Rescaler
146
+
147
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
148
+ params:
149
+ h_key: height # USER: you might wanna adapt this for your custom dataset
150
+ w_key: width # USER: you might wanna adapt this for your custom dataset
151
+
152
+ loader:
153
+ batch_size: 64
154
+ num_workers: 6
155
+
156
+ lightning:
157
+ modelcheckpoint:
158
+ params:
159
+ every_n_train_steps: 5000
160
+
161
+ callbacks:
162
+ metrics_over_trainsteps_checkpoint:
163
+ params:
164
+ every_n_train_steps: 25000
165
+
166
+ image_logger:
167
+ target: main.ImageLogger
168
+ params:
169
+ disabled: False
170
+ enable_autocast: False
171
+ batch_frequency: 1000
172
+ max_images: 8
173
+ increase_log_steps: True
174
+ log_first_step: False
175
+ log_images_kwargs:
176
+ use_ema_scope: False
177
+ N: 8
178
+ n_rows: 2
179
+
180
+ trainer:
181
+ devices: 0,
182
+ benchmark: True
183
+ num_sanity_val_steps: 0
184
+ accumulate_grad_batches: 1
185
+ max_epochs: 1000
configs/example_training/toy/cifar10_cond.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 3
17
+ out_channels: 3
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+ num_classes: sequential
24
+ adm_in_channels: 128
25
+
26
+ conditioner_config:
27
+ target: sgm.modules.GeneralConditioner
28
+ params:
29
+ emb_models:
30
+ - is_trainable: True
31
+ input_key: cls
32
+ ucg_rate: 0.2
33
+ target: sgm.modules.encoders.modules.ClassEmbedder
34
+ params:
35
+ embed_dim: 128
36
+ n_classes: 10
37
+
38
+ first_stage_config:
39
+ target: sgm.models.autoencoder.IdentityFirstStage
40
+
41
+ loss_fn_config:
42
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
+ params:
44
+ loss_weighting_config:
45
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46
+ params:
47
+ sigma_data: 1.0
48
+ sigma_sampler_config:
49
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50
+
51
+ sampler_config:
52
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53
+ params:
54
+ num_steps: 50
55
+
56
+ discretization_config:
57
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58
+
59
+ guider_config:
60
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61
+ params:
62
+ scale: 3.0
63
+
64
+ data:
65
+ target: sgm.data.cifar10.CIFAR10Loader
66
+ params:
67
+ batch_size: 512
68
+ num_workers: 1
69
+
70
+ lightning:
71
+ modelcheckpoint:
72
+ params:
73
+ every_n_train_steps: 5000
74
+
75
+ callbacks:
76
+ metrics_over_trainsteps_checkpoint:
77
+ params:
78
+ every_n_train_steps: 25000
79
+
80
+ image_logger:
81
+ target: main.ImageLogger
82
+ params:
83
+ disabled: False
84
+ batch_frequency: 1000
85
+ max_images: 64
86
+ increase_log_steps: True
87
+ log_first_step: False
88
+ log_images_kwargs:
89
+ use_ema_scope: False
90
+ N: 64
91
+ n_rows: 8
92
+
93
+ trainer:
94
+ devices: 0,
95
+ benchmark: True
96
+ num_sanity_val_steps: 0
97
+ accumulate_grad_batches: 1
98
+ max_epochs: 20
configs/example_training/toy/mnist.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 1
17
+ out_channels: 1
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+
24
+ first_stage_config:
25
+ target: sgm.models.autoencoder.IdentityFirstStage
26
+
27
+ loss_fn_config:
28
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
29
+ params:
30
+ loss_weighting_config:
31
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
32
+ params:
33
+ sigma_data: 1.0
34
+ sigma_sampler_config:
35
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
36
+
37
+ sampler_config:
38
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
39
+ params:
40
+ num_steps: 50
41
+
42
+ discretization_config:
43
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
44
+
45
+ data:
46
+ target: sgm.data.mnist.MNISTLoader
47
+ params:
48
+ batch_size: 512
49
+ num_workers: 1
50
+
51
+ lightning:
52
+ modelcheckpoint:
53
+ params:
54
+ every_n_train_steps: 5000
55
+
56
+ callbacks:
57
+ metrics_over_trainsteps_checkpoint:
58
+ params:
59
+ every_n_train_steps: 25000
60
+
61
+ image_logger:
62
+ target: main.ImageLogger
63
+ params:
64
+ disabled: False
65
+ batch_frequency: 1000
66
+ max_images: 64
67
+ increase_log_steps: False
68
+ log_first_step: False
69
+ log_images_kwargs:
70
+ use_ema_scope: False
71
+ N: 64
72
+ n_rows: 8
73
+
74
+ trainer:
75
+ devices: 0,
76
+ benchmark: True
77
+ num_sanity_val_steps: 0
78
+ accumulate_grad_batches: 1
79
+ max_epochs: 10
configs/example_training/toy/mnist_cond.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 1
17
+ out_channels: 1
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+ num_classes: sequential
24
+ adm_in_channels: 128
25
+
26
+ conditioner_config:
27
+ target: sgm.modules.GeneralConditioner
28
+ params:
29
+ emb_models:
30
+ - is_trainable: True
31
+ input_key: cls
32
+ ucg_rate: 0.2
33
+ target: sgm.modules.encoders.modules.ClassEmbedder
34
+ params:
35
+ embed_dim: 128
36
+ n_classes: 10
37
+
38
+ first_stage_config:
39
+ target: sgm.models.autoencoder.IdentityFirstStage
40
+
41
+ loss_fn_config:
42
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
+ params:
44
+ loss_weighting_config:
45
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
46
+ params:
47
+ sigma_data: 1.0
48
+ sigma_sampler_config:
49
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
50
+
51
+ sampler_config:
52
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
53
+ params:
54
+ num_steps: 50
55
+
56
+ discretization_config:
57
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
58
+
59
+ guider_config:
60
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
61
+ params:
62
+ scale: 3.0
63
+
64
+ data:
65
+ target: sgm.data.mnist.MNISTLoader
66
+ params:
67
+ batch_size: 512
68
+ num_workers: 1
69
+
70
+ lightning:
71
+ modelcheckpoint:
72
+ params:
73
+ every_n_train_steps: 5000
74
+
75
+ callbacks:
76
+ metrics_over_trainsteps_checkpoint:
77
+ params:
78
+ every_n_train_steps: 25000
79
+
80
+ image_logger:
81
+ target: main.ImageLogger
82
+ params:
83
+ disabled: False
84
+ batch_frequency: 1000
85
+ max_images: 16
86
+ increase_log_steps: True
87
+ log_first_step: False
88
+ log_images_kwargs:
89
+ use_ema_scope: False
90
+ N: 16
91
+ n_rows: 4
92
+
93
+ trainer:
94
+ devices: 0,
95
+ benchmark: True
96
+ num_sanity_val_steps: 0
97
+ accumulate_grad_batches: 1
98
+ max_epochs: 20
configs/example_training/toy/mnist_cond_discrete_eps.yaml ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
7
+ params:
8
+ num_idx: 1000
9
+
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12
+ discretization_config:
13
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
14
+
15
+ network_config:
16
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ in_channels: 1
19
+ out_channels: 1
20
+ model_channels: 32
21
+ attention_resolutions: []
22
+ num_res_blocks: 4
23
+ channel_mult: [1, 2, 2]
24
+ num_head_channels: 32
25
+ num_classes: sequential
26
+ adm_in_channels: 128
27
+
28
+ conditioner_config:
29
+ target: sgm.modules.GeneralConditioner
30
+ params:
31
+ emb_models:
32
+ - is_trainable: True
33
+ input_key: cls
34
+ ucg_rate: 0.2
35
+ target: sgm.modules.encoders.modules.ClassEmbedder
36
+ params:
37
+ embed_dim: 128
38
+ n_classes: 10
39
+
40
+ first_stage_config:
41
+ target: sgm.models.autoencoder.IdentityFirstStage
42
+
43
+ loss_fn_config:
44
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45
+ params:
46
+ loss_weighting_config:
47
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48
+ sigma_sampler_config:
49
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
50
+ params:
51
+ num_idx: 1000
52
+
53
+ discretization_config:
54
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
55
+
56
+ sampler_config:
57
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
58
+ params:
59
+ num_steps: 50
60
+
61
+ discretization_config:
62
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
63
+
64
+ guider_config:
65
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
66
+ params:
67
+ scale: 5.0
68
+
69
+ data:
70
+ target: sgm.data.mnist.MNISTLoader
71
+ params:
72
+ batch_size: 512
73
+ num_workers: 1
74
+
75
+ lightning:
76
+ modelcheckpoint:
77
+ params:
78
+ every_n_train_steps: 5000
79
+
80
+ callbacks:
81
+ metrics_over_trainsteps_checkpoint:
82
+ params:
83
+ every_n_train_steps: 25000
84
+
85
+ image_logger:
86
+ target: main.ImageLogger
87
+ params:
88
+ disabled: False
89
+ batch_frequency: 1000
90
+ max_images: 16
91
+ increase_log_steps: True
92
+ log_first_step: False
93
+ log_images_kwargs:
94
+ use_ema_scope: False
95
+ N: 16
96
+ n_rows: 4
97
+
98
+ trainer:
99
+ devices: 0,
100
+ benchmark: True
101
+ num_sanity_val_steps: 0
102
+ accumulate_grad_batches: 1
103
+ max_epochs: 20
configs/example_training/toy/mnist_cond_l1_loss.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ denoiser_config:
6
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
7
+ params:
8
+ scaling_config:
9
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
10
+ params:
11
+ sigma_data: 1.0
12
+
13
+ network_config:
14
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
15
+ params:
16
+ in_channels: 1
17
+ out_channels: 1
18
+ model_channels: 32
19
+ attention_resolutions: []
20
+ num_res_blocks: 4
21
+ channel_mult: [1, 2, 2]
22
+ num_head_channels: 32
23
+ num_classes: sequential
24
+ adm_in_channels: 128
25
+
26
+ conditioner_config:
27
+ target: sgm.modules.GeneralConditioner
28
+ params:
29
+ emb_models:
30
+ - is_trainable: True
31
+ input_key: cls
32
+ ucg_rate: 0.2
33
+ target: sgm.modules.encoders.modules.ClassEmbedder
34
+ params:
35
+ embed_dim: 128
36
+ n_classes: 10
37
+
38
+ first_stage_config:
39
+ target: sgm.models.autoencoder.IdentityFirstStage
40
+
41
+ loss_fn_config:
42
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
43
+ params:
44
+ loss_type: l1
45
+ loss_weighting_config:
46
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
47
+ params:
48
+ sigma_data: 1.0
49
+ sigma_sampler_config:
50
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
51
+
52
+ sampler_config:
53
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
54
+ params:
55
+ num_steps: 50
56
+
57
+ discretization_config:
58
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
59
+
60
+ guider_config:
61
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
62
+ params:
63
+ scale: 3.0
64
+
65
+ data:
66
+ target: sgm.data.mnist.MNISTLoader
67
+ params:
68
+ batch_size: 512
69
+ num_workers: 1
70
+
71
+ lightning:
72
+ modelcheckpoint:
73
+ params:
74
+ every_n_train_steps: 5000
75
+
76
+ callbacks:
77
+ metrics_over_trainsteps_checkpoint:
78
+ params:
79
+ every_n_train_steps: 25000
80
+
81
+ image_logger:
82
+ target: main.ImageLogger
83
+ params:
84
+ disabled: False
85
+ batch_frequency: 1000
86
+ max_images: 64
87
+ increase_log_steps: True
88
+ log_first_step: False
89
+ log_images_kwargs:
90
+ use_ema_scope: False
91
+ N: 64
92
+ n_rows: 8
93
+
94
+ trainer:
95
+ devices: 0,
96
+ benchmark: True
97
+ num_sanity_val_steps: 0
98
+ accumulate_grad_batches: 1
99
+ max_epochs: 20
configs/example_training/toy/mnist_cond_with_ema.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ use_ema: True
6
+
7
+ denoiser_config:
8
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
9
+ params:
10
+ scaling_config:
11
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
12
+ params:
13
+ sigma_data: 1.0
14
+
15
+ network_config:
16
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
17
+ params:
18
+ in_channels: 1
19
+ out_channels: 1
20
+ model_channels: 32
21
+ attention_resolutions: []
22
+ num_res_blocks: 4
23
+ channel_mult: [1, 2, 2]
24
+ num_head_channels: 32
25
+ num_classes: sequential
26
+ adm_in_channels: 128
27
+
28
+ conditioner_config:
29
+ target: sgm.modules.GeneralConditioner
30
+ params:
31
+ emb_models:
32
+ - is_trainable: True
33
+ input_key: cls
34
+ ucg_rate: 0.2
35
+ target: sgm.modules.encoders.modules.ClassEmbedder
36
+ params:
37
+ embed_dim: 128
38
+ n_classes: 10
39
+
40
+ first_stage_config:
41
+ target: sgm.models.autoencoder.IdentityFirstStage
42
+
43
+ loss_fn_config:
44
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
45
+ params:
46
+ loss_weighting_config:
47
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
48
+ params:
49
+ sigma_data: 1.0
50
+ sigma_sampler_config:
51
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
52
+
53
+ sampler_config:
54
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
55
+ params:
56
+ num_steps: 50
57
+
58
+ discretization_config:
59
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
60
+
61
+ guider_config:
62
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
63
+ params:
64
+ scale: 3.0
65
+
66
+ data:
67
+ target: sgm.data.mnist.MNISTLoader
68
+ params:
69
+ batch_size: 512
70
+ num_workers: 1
71
+
72
+ lightning:
73
+ modelcheckpoint:
74
+ params:
75
+ every_n_train_steps: 5000
76
+
77
+ callbacks:
78
+ metrics_over_trainsteps_checkpoint:
79
+ params:
80
+ every_n_train_steps: 25000
81
+
82
+ image_logger:
83
+ target: main.ImageLogger
84
+ params:
85
+ disabled: False
86
+ batch_frequency: 1000
87
+ max_images: 64
88
+ increase_log_steps: True
89
+ log_first_step: False
90
+ log_images_kwargs:
91
+ use_ema_scope: False
92
+ N: 64
93
+ n_rows: 8
94
+
95
+ trainer:
96
+ devices: 0,
97
+ benchmark: True
98
+ num_sanity_val_steps: 0
99
+ accumulate_grad_batches: 1
100
+ max_epochs: 20
configs/example_training/txt2img-clipl-legacy-ucg-training.yaml ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - txt
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
+ params:
22
+ num_idx: 1000
23
+
24
+ scaling_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
+ discretization_config:
27
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
+
29
+ network_config:
30
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ use_checkpoint: True
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [1, 2, 4]
37
+ num_res_blocks: 2
38
+ channel_mult: [1, 2, 4, 4]
39
+ num_head_channels: 64
40
+ num_classes: sequential
41
+ adm_in_channels: 1792
42
+ num_heads: 1
43
+ transformer_depth: 1
44
+ context_dim: 768
45
+ spatial_transformer_attn_type: softmax-xformers
46
+
47
+ conditioner_config:
48
+ target: sgm.modules.GeneralConditioner
49
+ params:
50
+ emb_models:
51
+ - is_trainable: True
52
+ input_key: txt
53
+ ucg_rate: 0.1
54
+ legacy_ucg_value: ""
55
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
56
+ params:
57
+ always_return_pooled: True
58
+
59
+ - is_trainable: False
60
+ ucg_rate: 0.1
61
+ input_key: original_size_as_tuple
62
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
+ params:
64
+ outdim: 256
65
+
66
+ - is_trainable: False
67
+ input_key: crop_coords_top_left
68
+ ucg_rate: 0.1
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256
72
+
73
+ first_stage_config:
74
+ target: sgm.models.autoencoder.AutoencoderKL
75
+ params:
76
+ ckpt_path: CKPT_PATH
77
+ embed_dim: 4
78
+ monitor: val/rec_loss
79
+ ddconfig:
80
+ attn_type: vanilla-xformers
81
+ double_z: true
82
+ z_channels: 4
83
+ resolution: 256
84
+ in_channels: 3
85
+ out_ch: 3
86
+ ch: 128
87
+ ch_mult: [ 1, 2, 4, 4 ]
88
+ num_res_blocks: 2
89
+ attn_resolutions: [ ]
90
+ dropout: 0.0
91
+ lossconfig:
92
+ target: torch.nn.Identity
93
+
94
+ loss_fn_config:
95
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
+ params:
97
+ loss_weighting_config:
98
+ target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
99
+ sigma_sampler_config:
100
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
101
+ params:
102
+ num_idx: 1000
103
+
104
+ discretization_config:
105
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
106
+
107
+ sampler_config:
108
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
109
+ params:
110
+ num_steps: 50
111
+
112
+ discretization_config:
113
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
114
+
115
+ guider_config:
116
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
117
+ params:
118
+ scale: 7.5
119
+
120
+ data:
121
+ target: sgm.data.dataset.StableDataModuleFromConfig
122
+ params:
123
+ train:
124
+ datapipeline:
125
+ urls:
126
+ # USER: adapt this path the root of your custom dataset
127
+ - DATA_PATH
128
+ pipeline_config:
129
+ shardshuffle: 10000
130
+ sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
131
+
132
+ decoders:
133
+ - pil
134
+
135
+ postprocessors:
136
+ - target: sdata.mappers.TorchVisionImageTransforms
137
+ params:
138
+ key: jpg # USER: you might wanna adapt this for your custom dataset
139
+ transforms:
140
+ - target: torchvision.transforms.Resize
141
+ params:
142
+ size: 256
143
+ interpolation: 3
144
+ - target: torchvision.transforms.ToTensor
145
+ - target: sdata.mappers.Rescaler
146
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
147
+ # USER: you might wanna use non-default parameters due to your custom dataset
148
+
149
+ loader:
150
+ batch_size: 64
151
+ num_workers: 6
152
+
153
+ lightning:
154
+ modelcheckpoint:
155
+ params:
156
+ every_n_train_steps: 5000
157
+
158
+ callbacks:
159
+ metrics_over_trainsteps_checkpoint:
160
+ params:
161
+ every_n_train_steps: 25000
162
+
163
+ image_logger:
164
+ target: main.ImageLogger
165
+ params:
166
+ disabled: False
167
+ enable_autocast: False
168
+ batch_frequency: 1000
169
+ max_images: 8
170
+ increase_log_steps: True
171
+ log_first_step: False
172
+ log_images_kwargs:
173
+ use_ema_scope: False
174
+ N: 8
175
+ n_rows: 2
176
+
177
+ trainer:
178
+ devices: 0,
179
+ benchmark: True
180
+ num_sanity_val_steps: 0
181
+ accumulate_grad_batches: 1
182
+ max_epochs: 1000
configs/example_training/txt2img-clipl.yaml ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ log_keys:
8
+ - txt
9
+
10
+ scheduler_config:
11
+ target: sgm.lr_scheduler.LambdaLinearScheduler
12
+ params:
13
+ warm_up_steps: [10000]
14
+ cycle_lengths: [10000000000000]
15
+ f_start: [1.e-6]
16
+ f_max: [1.]
17
+ f_min: [1.]
18
+
19
+ denoiser_config:
20
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
21
+ params:
22
+ num_idx: 1000
23
+
24
+ scaling_config:
25
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
26
+ discretization_config:
27
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
28
+
29
+ network_config:
30
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ use_checkpoint: True
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [1, 2, 4]
37
+ num_res_blocks: 2
38
+ channel_mult: [1, 2, 4, 4]
39
+ num_head_channels: 64
40
+ num_classes: sequential
41
+ adm_in_channels: 1792
42
+ num_heads: 1
43
+ transformer_depth: 1
44
+ context_dim: 768
45
+ spatial_transformer_attn_type: softmax-xformers
46
+
47
+ conditioner_config:
48
+ target: sgm.modules.GeneralConditioner
49
+ params:
50
+ emb_models:
51
+ - is_trainable: True
52
+ input_key: txt
53
+ ucg_rate: 0.1
54
+ legacy_ucg_value: ""
55
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
56
+ params:
57
+ always_return_pooled: True
58
+
59
+ - is_trainable: False
60
+ ucg_rate: 0.1
61
+ input_key: original_size_as_tuple
62
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
63
+ params:
64
+ outdim: 256
65
+
66
+ - is_trainable: False
67
+ input_key: crop_coords_top_left
68
+ ucg_rate: 0.1
69
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
70
+ params:
71
+ outdim: 256
72
+
73
+ first_stage_config:
74
+ target: sgm.models.autoencoder.AutoencoderKL
75
+ params:
76
+ ckpt_path: CKPT_PATH
77
+ embed_dim: 4
78
+ monitor: val/rec_loss
79
+ ddconfig:
80
+ attn_type: vanilla-xformers
81
+ double_z: true
82
+ z_channels: 4
83
+ resolution: 256
84
+ in_channels: 3
85
+ out_ch: 3
86
+ ch: 128
87
+ ch_mult: [1, 2, 4, 4]
88
+ num_res_blocks: 2
89
+ attn_resolutions: []
90
+ dropout: 0.0
91
+ lossconfig:
92
+ target: torch.nn.Identity
93
+
94
+ loss_fn_config:
95
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
96
+ params:
97
+ loss_weighting_config:
98
+ target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
99
+ sigma_sampler_config:
100
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
101
+ params:
102
+ num_idx: 1000
103
+
104
+ discretization_config:
105
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
106
+
107
+ sampler_config:
108
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
109
+ params:
110
+ num_steps: 50
111
+
112
+ discretization_config:
113
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
114
+
115
+ guider_config:
116
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFG
117
+ params:
118
+ scale: 7.5
119
+
120
+ data:
121
+ target: sgm.data.dataset.StableDataModuleFromConfig
122
+ params:
123
+ train:
124
+ datapipeline:
125
+ urls:
126
+ # USER: adapt this path the root of your custom dataset
127
+ - DATA_PATH
128
+ pipeline_config:
129
+ shardshuffle: 10000
130
+ sample_shuffle: 10000
131
+
132
+
133
+ decoders:
134
+ - pil
135
+
136
+ postprocessors:
137
+ - target: sdata.mappers.TorchVisionImageTransforms
138
+ params:
139
+ key: jpg # USER: you might wanna adapt this for your custom dataset
140
+ transforms:
141
+ - target: torchvision.transforms.Resize
142
+ params:
143
+ size: 256
144
+ interpolation: 3
145
+ - target: torchvision.transforms.ToTensor
146
+ - target: sdata.mappers.Rescaler
147
+ # USER: you might wanna use non-default parameters due to your custom dataset
148
+ - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
149
+ # USER: you might wanna use non-default parameters due to your custom dataset
150
+
151
+ loader:
152
+ batch_size: 64
153
+ num_workers: 6
154
+
155
+ lightning:
156
+ modelcheckpoint:
157
+ params:
158
+ every_n_train_steps: 5000
159
+
160
+ callbacks:
161
+ metrics_over_trainsteps_checkpoint:
162
+ params:
163
+ every_n_train_steps: 25000
164
+
165
+ image_logger:
166
+ target: main.ImageLogger
167
+ params:
168
+ disabled: False
169
+ enable_autocast: False
170
+ batch_frequency: 1000
171
+ max_images: 8
172
+ increase_log_steps: True
173
+ log_first_step: False
174
+ log_images_kwargs:
175
+ use_ema_scope: False
176
+ N: 8
177
+ n_rows: 2
178
+
179
+ trainer:
180
+ devices: 0,
181
+ benchmark: True
182
+ num_sanity_val_steps: 0
183
+ accumulate_grad_batches: 1
184
+ max_epochs: 1000
scripts/.DS_Store CHANGED
Binary files a/scripts/.DS_Store and b/scripts/.DS_Store differ
 
simple_video_sample.py CHANGED
@@ -18,8 +18,9 @@ from scripts.util.detection.nsfw_and_watermark_dectection import \
18
  from sgm.inference.helpers import embed_watermark
19
  from sgm.util import default, instantiate_from_config
20
 
 
21
  def sample(
22
- input_path: str = "assets/doggo.png", # Can either be image file or folder with image files
23
  num_frames: Optional[int] = None,
24
  num_steps: Optional[int] = None,
25
  version: str = "svd",
@@ -274,4 +275,4 @@ def load_model(
274
 
275
 
276
  if __name__ == "__main__":
277
- Fire(sample)
 
18
  from sgm.inference.helpers import embed_watermark
19
  from sgm.util import default, instantiate_from_config
20
 
21
+
22
  def sample(
23
+ input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
24
  num_frames: Optional[int] = None,
25
  num_steps: Optional[int] = None,
26
  version: str = "svd",
 
275
 
276
 
277
  if __name__ == "__main__":
278
+ Fire(sample)