Fabrice-TIERCELIN commited on
Commit
e0eb41e
·
verified ·
1 Parent(s): ac6e4f8

Upload 3 files

Browse files
Files changed (3) hide show
  1. sgm/__init__.py +4 -0
  2. sgm/lr_scheduler.py +135 -0
  3. sgm/util.py +248 -0
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .models import AutoencodingEngine, DiffusionEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ warm_up_steps,
12
+ lr_min,
13
+ lr_max,
14
+ lr_start,
15
+ max_decay_steps,
16
+ verbosity_interval=0,
17
+ ):
18
+ self.lr_warm_up_steps = warm_up_steps
19
+ self.lr_start = lr_start
20
+ self.lr_min = lr_min
21
+ self.lr_max = lr_max
22
+ self.lr_max_decay_steps = max_decay_steps
23
+ self.last_lr = 0.0
24
+ self.verbosity_interval = verbosity_interval
25
+
26
+ def schedule(self, n, **kwargs):
27
+ if self.verbosity_interval > 0:
28
+ if n % self.verbosity_interval == 0:
29
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
30
+ if n < self.lr_warm_up_steps:
31
+ lr = (
32
+ self.lr_max - self.lr_start
33
+ ) / self.lr_warm_up_steps * n + self.lr_start
34
+ self.last_lr = lr
35
+ return lr
36
+ else:
37
+ t = (n - self.lr_warm_up_steps) / (
38
+ self.lr_max_decay_steps - self.lr_warm_up_steps
39
+ )
40
+ t = min(t, 1.0)
41
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
42
+ 1 + np.cos(t * np.pi)
43
+ )
44
+ self.last_lr = lr
45
+ return lr
46
+
47
+ def __call__(self, n, **kwargs):
48
+ return self.schedule(n, **kwargs)
49
+
50
+
51
+ class LambdaWarmUpCosineScheduler2:
52
+ """
53
+ supports repeated iterations, configurable via lists
54
+ note: use with a base_lr of 1.0.
55
+ """
56
+
57
+ def __init__(
58
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
59
+ ):
60
+ assert (
61
+ len(warm_up_steps)
62
+ == len(f_min)
63
+ == len(f_max)
64
+ == len(f_start)
65
+ == len(cycle_lengths)
66
+ )
67
+ self.lr_warm_up_steps = warm_up_steps
68
+ self.f_start = f_start
69
+ self.f_min = f_min
70
+ self.f_max = f_max
71
+ self.cycle_lengths = cycle_lengths
72
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
73
+ self.last_f = 0.0
74
+ self.verbosity_interval = verbosity_interval
75
+
76
+ def find_in_interval(self, n):
77
+ interval = 0
78
+ for cl in self.cum_cycles[1:]:
79
+ if n <= cl:
80
+ return interval
81
+ interval += 1
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0:
88
+ print(
89
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
90
+ f"current cycle {cycle}"
91
+ )
92
+ if n < self.lr_warm_up_steps[cycle]:
93
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
94
+ cycle
95
+ ] * n + self.f_start[cycle]
96
+ self.last_f = f
97
+ return f
98
+ else:
99
+ t = (n - self.lr_warm_up_steps[cycle]) / (
100
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
101
+ )
102
+ t = min(t, 1.0)
103
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
104
+ 1 + np.cos(t * np.pi)
105
+ )
106
+ self.last_f = f
107
+ return f
108
+
109
+ def __call__(self, n, **kwargs):
110
+ return self.schedule(n, **kwargs)
111
+
112
+
113
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
114
+ def schedule(self, n, **kwargs):
115
+ cycle = self.find_in_interval(n)
116
+ n = n - self.cum_cycles[cycle]
117
+ if self.verbosity_interval > 0:
118
+ if n % self.verbosity_interval == 0:
119
+ print(
120
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
121
+ f"current cycle {cycle}"
122
+ )
123
+
124
+ if n < self.lr_warm_up_steps[cycle]:
125
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
126
+ cycle
127
+ ] * n + self.f_start[cycle]
128
+ self.last_f = f
129
+ return f
130
+ else:
131
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
132
+ self.cycle_lengths[cycle] - n
133
+ ) / (self.cycle_lengths[cycle])
134
+ self.last_f = f
135
+ return f
sgm/util.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import importlib
3
+ import os
4
+ from functools import partial
5
+ from inspect import isfunction
6
+
7
+ import fsspec
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ from safetensors.torch import load_file as load_safetensors
12
+
13
+
14
+ def disabled_train(self, mode=True):
15
+ """Overwrite model.train with this function to make sure train/eval mode
16
+ does not change anymore."""
17
+ return self
18
+
19
+
20
+ def get_string_from_tuple(s):
21
+ try:
22
+ # Check if the string starts and ends with parentheses
23
+ if s[0] == "(" and s[-1] == ")":
24
+ # Convert the string to a tuple
25
+ t = eval(s)
26
+ # Check if the type of t is tuple
27
+ if type(t) == tuple:
28
+ return t[0]
29
+ else:
30
+ pass
31
+ except:
32
+ pass
33
+ return s
34
+
35
+
36
+ def is_power_of_two(n):
37
+ """
38
+ chat.openai.com/chat
39
+ Return True if n is a power of 2, otherwise return False.
40
+
41
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
42
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
43
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
44
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
45
+
46
+ """
47
+ if n <= 0:
48
+ return False
49
+ return (n & (n - 1)) == 0
50
+
51
+
52
+ def autocast(f, enabled=True):
53
+ def do_autocast(*args, **kwargs):
54
+ with torch.cuda.amp.autocast(
55
+ enabled=enabled,
56
+ dtype=torch.get_autocast_gpu_dtype(),
57
+ cache_enabled=torch.is_autocast_cache_enabled(),
58
+ ):
59
+ return f(*args, **kwargs)
60
+
61
+ return do_autocast
62
+
63
+
64
+ def load_partial_from_config(config):
65
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
66
+
67
+
68
+ def log_txt_as_img(wh, xc, size=10):
69
+ # wh a tuple of (width, height)
70
+ # xc a list of captions to plot
71
+ b = len(xc)
72
+ txts = list()
73
+ for bi in range(b):
74
+ txt = Image.new("RGB", wh, color="white")
75
+ draw = ImageDraw.Draw(txt)
76
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
77
+ nc = int(40 * (wh[0] / 256))
78
+ if isinstance(xc[bi], list):
79
+ text_seq = xc[bi][0]
80
+ else:
81
+ text_seq = xc[bi]
82
+ lines = "\n".join(
83
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
84
+ )
85
+
86
+ try:
87
+ draw.text((0, 0), lines, fill="black", font=font)
88
+ except UnicodeEncodeError:
89
+ print("Cant encode string for logging. Skipping.")
90
+
91
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
92
+ txts.append(txt)
93
+ txts = np.stack(txts)
94
+ txts = torch.tensor(txts)
95
+ return txts
96
+
97
+
98
+ def partialclass(cls, *args, **kwargs):
99
+ class NewCls(cls):
100
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
101
+
102
+ return NewCls
103
+
104
+
105
+ def make_path_absolute(path):
106
+ fs, p = fsspec.core.url_to_fs(path)
107
+ if fs.protocol == "file":
108
+ return os.path.abspath(p)
109
+ return path
110
+
111
+
112
+ def ismap(x):
113
+ if not isinstance(x, torch.Tensor):
114
+ return False
115
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
116
+
117
+
118
+ def isimage(x):
119
+ if not isinstance(x, torch.Tensor):
120
+ return False
121
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
122
+
123
+
124
+ def isheatmap(x):
125
+ if not isinstance(x, torch.Tensor):
126
+ return False
127
+
128
+ return x.ndim == 2
129
+
130
+
131
+ def isneighbors(x):
132
+ if not isinstance(x, torch.Tensor):
133
+ return False
134
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
135
+
136
+
137
+ def exists(x):
138
+ return x is not None
139
+
140
+
141
+ def expand_dims_like(x, y):
142
+ while x.dim() != y.dim():
143
+ x = x.unsqueeze(-1)
144
+ return x
145
+
146
+
147
+ def default(val, d):
148
+ if exists(val):
149
+ return val
150
+ return d() if isfunction(d) else d
151
+
152
+
153
+ def mean_flat(tensor):
154
+ """
155
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
156
+ Take the mean over all non-batch dimensions.
157
+ """
158
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
159
+
160
+
161
+ def count_params(model, verbose=False):
162
+ total_params = sum(p.numel() for p in model.parameters())
163
+ if verbose:
164
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
165
+ return total_params
166
+
167
+
168
+ def instantiate_from_config(config):
169
+ if not "target" in config:
170
+ if config == "__is_first_stage__":
171
+ return None
172
+ elif config == "__is_unconditional__":
173
+ return None
174
+ raise KeyError("Expected key `target` to instantiate.")
175
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
176
+
177
+
178
+ def get_obj_from_str(string, reload=False, invalidate_cache=True):
179
+ module, cls = string.rsplit(".", 1)
180
+ if invalidate_cache:
181
+ importlib.invalidate_caches()
182
+ if reload:
183
+ module_imp = importlib.import_module(module)
184
+ importlib.reload(module_imp)
185
+ return getattr(importlib.import_module(module, package=None), cls)
186
+
187
+
188
+ def append_zero(x):
189
+ return torch.cat([x, x.new_zeros([1])])
190
+
191
+
192
+ def append_dims(x, target_dims):
193
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
194
+ dims_to_append = target_dims - x.ndim
195
+ if dims_to_append < 0:
196
+ raise ValueError(
197
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
198
+ )
199
+ return x[(...,) + (None,) * dims_to_append]
200
+
201
+
202
+ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
203
+ print(f"Loading model from {ckpt}")
204
+ if ckpt.endswith("ckpt"):
205
+ pl_sd = torch.load(ckpt, map_location="cpu")
206
+ if "global_step" in pl_sd:
207
+ print(f"Global Step: {pl_sd['global_step']}")
208
+ sd = pl_sd["state_dict"]
209
+ elif ckpt.endswith("safetensors"):
210
+ sd = load_safetensors(ckpt)
211
+ else:
212
+ raise NotImplementedError
213
+
214
+ model = instantiate_from_config(config.model)
215
+
216
+ m, u = model.load_state_dict(sd, strict=False)
217
+
218
+ if len(m) > 0 and verbose:
219
+ print("missing keys:")
220
+ print(m)
221
+ if len(u) > 0 and verbose:
222
+ print("unexpected keys:")
223
+ print(u)
224
+
225
+ if freeze:
226
+ for param in model.parameters():
227
+ param.requires_grad = False
228
+
229
+ model.eval()
230
+ return model
231
+
232
+
233
+ def get_configs_path() -> str:
234
+ """
235
+ Get the `configs` directory.
236
+ For a working copy, this is the one in the root of the repository,
237
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
238
+ """
239
+ this_dir = os.path.dirname(__file__)
240
+ candidates = (
241
+ os.path.join(this_dir, "configs"),
242
+ os.path.join(this_dir, "..", "configs"),
243
+ )
244
+ for candidate in candidates:
245
+ candidate = os.path.abspath(candidate)
246
+ if os.path.isdir(candidate):
247
+ return candidate
248
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")