geoffrey-dawson commited on
Commit
c802c28
·
verified ·
1 Parent(s): ac176b5

Upload config.yaml

Browse files
Files changed (1) hide show
  1. config.yaml +157 -0
config.yaml ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lightning.pytorch==2.1.1
2
+ seed_everything: 0
3
+ trainer:
4
+ accelerator: gpu # we can also use auto or cpu
5
+ strategy: auto
6
+ devices: auto
7
+ num_nodes: 1
8
+ logger: True # will use tensorboardlogger
9
+
10
+ callbacks:
11
+ - class_path: RichProgressBar
12
+ - class_path: LearningRateMonitor
13
+ init_args:
14
+ logging_interval: epoch
15
+ - class_path: EarlyStopping
16
+ init_args:
17
+ monitor: val/loss
18
+ patience: 30
19
+
20
+ max_epochs: 200
21
+ check_val_every_n_epoch: 1
22
+ log_every_n_steps: 1
23
+ enable_checkpointing: true
24
+ default_root_dir: ./../data/fine_tuning/granite_geospatial_uki_flood_detection_v2
25
+ data:
26
+ class_path: GenericNonGeoSegmentationDataModule
27
+ init_args:
28
+ batch_size: 4
29
+ num_workers: 1
30
+ constant_scale: 0.0001
31
+ dataset_bands: # what bands are in your data
32
+ - VV
33
+ - VH
34
+ - BLUE
35
+ - GREEN
36
+ - RED
37
+ - NIR_NARROW
38
+ - SWIR_1
39
+ - SWIR_2
40
+ - CLOUD
41
+ output_bands: # which bands do you want to fine-tune
42
+ - BLUE
43
+ - GREEN
44
+ - RED
45
+ - NIR_NARROW
46
+ - SWIR_1
47
+ - SWIR_2
48
+ - VV
49
+ - VH
50
+ - CLOUD
51
+ rgb_indices:
52
+ - 4
53
+ - 3
54
+ - 2
55
+ train_data_root: ./../data/regions/combined_uki_spain/images/
56
+ train_label_data_root: ./../data/regions/combined_uki_spain/labels/
57
+ val_data_root: ./../data/regions/combined_uki_spain/images/
58
+ val_label_data_root: ./../data/regions/combined_uki_spain/labels/
59
+ test_data_root: ./../data/regions/combined_uki_spain/images/
60
+ test_label_data_root: ./../data/regions/combined_uki_spain/labels/
61
+ train_split: ./../data/regions/combined_uki_spain/splits/flood_train_data.txt
62
+ test_split: ./../data/regions/combined_uki_spain/splits/flood_test_data.txt
63
+ val_split: ./../data/regions/combined_uki_spain/splits/flood_val_data.txt
64
+ img_grep: "*_image.tif"
65
+ label_grep: "*_label.tif"
66
+ no_label_replace: -1
67
+ no_data_replace: 0
68
+ means:
69
+ - 0.1290484133335582 # BLUE
70
+ - 0.13423481405157794 # GREEN
71
+ - 0.1328938801112928 # RED
72
+ - 0.20036851044035797 # NIR_NARROW
73
+ - 0.13804629743141042 # SWIR_1
74
+ - 0.10409700513471637 # SWIR_2
75
+ - -0.0018052691820029847 # VV
76
+ - -0.0023712696527645486 # VH
77
+ - 0.000024014472961425782 #CLOUD
78
+
79
+ stds:
80
+ - 0.25406999374272976
81
+ - 0.22949378991348005
82
+ - 0.21689414406289836
83
+ - 0.22552362238920548
84
+ - 0.1600542128720416
85
+ - 0.12602917719190815
86
+ - 0.0011294842635096356
87
+ - 0.0008879269711519241
88
+ - 0.00004271712050839232
89
+
90
+ num_classes: 2
91
+
92
+ model:
93
+ class_path: terratorch.tasks.SemanticSegmentationTask
94
+ init_args:
95
+ model_args:
96
+ decoder: FCNDecoder
97
+ backbone_pretrained: true
98
+ backbone: granite_geospatial_uki
99
+ backbone_pretrained_cfg_overlay:
100
+ file: ./../data/checkpoints/granite_geospatial_uki.pt
101
+ backbone_pretrain_img_size: 512
102
+ decoder_channels: 256
103
+ # in_channels: 9
104
+ backbone_bands:
105
+ - BLUE
106
+ - GREEN
107
+ - RED
108
+ - NIR_NARROW
109
+ - SWIR_1
110
+ - SWIR_2
111
+ - VV
112
+ - VH
113
+ - CLOUD
114
+ # num_frames: 1
115
+ num_classes: 2
116
+ head_dropout: 0.1
117
+ decoder_num_convs: 4
118
+ head_channel_list:
119
+ - 256
120
+ necks:
121
+ - name: SelectIndices
122
+ indices:
123
+ - -1
124
+ - name: ReshapeTokensToImage
125
+ loss: ce
126
+ aux_heads:
127
+ - name: aux_head
128
+ decoder: FCNDecoder
129
+ decoder_args:
130
+ decoder_channels: 256
131
+ decoder_in_index: -1
132
+ decoder_num_convs: 2
133
+ head_dropout: 0.1
134
+ aux_loss:
135
+ aux_head: 1.0
136
+ ignore_index: -1
137
+ class_weights:
138
+ - 0.3
139
+ - 0.7
140
+ freeze_backbone: false
141
+ freeze_decoder: false
142
+ model_factory: EncoderDecoderFactory
143
+ tiled_inference_parameters:
144
+ h_crop: 512
145
+ h_stride: 496
146
+ w_crop: 512
147
+ w_stride: 496
148
+ average_patches: true
149
+ optimizer:
150
+ class_path: torch.optim.AdamW
151
+ init_args:
152
+ lr: 6.e-5
153
+ weight_decay: 0.05
154
+ lr_scheduler:
155
+ class_path: ReduceLROnPlateau
156
+ init_args:
157
+ monitor: val/loss