Fabrice-TIERCELIN committed on
Commit
ef85e7a
·
verified ·
1 Parent(s): 4dfbfa2

Upload 2 files

Browse files
hyvideo/diffusion/schedulers/__init__.py CHANGED
@@ -1 +1 @@
1
- from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
 
1
+ from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py CHANGED
@@ -1,257 +1,257 @@
1
- # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- #
16
- # Modified from diffusers==0.29.2
17
- #
18
- # ==============================================================================
19
-
20
- from dataclasses import dataclass
21
- from typing import Optional, Tuple, Union
22
-
23
- import numpy as np
24
- import torch
25
-
26
- from diffusers.configuration_utils import ConfigMixin, register_to_config
27
- from diffusers.utils import BaseOutput, logging
28
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
-
30
-
31
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
-
33
-
34
- @dataclass
35
- class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36
- """
37
- Output class for the scheduler's `step` function output.
38
-
39
- Args:
40
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
- denoising loop.
43
- """
44
-
45
- prev_sample: torch.FloatTensor
46
-
47
-
48
- class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49
- """
50
- Euler scheduler.
51
-
52
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
- methods the library implements for all schedulers such as loading and saving.
54
-
55
- Args:
56
- num_train_timesteps (`int`, defaults to 1000):
57
- The number of diffusion steps to train the model.
58
- timestep_spacing (`str`, defaults to `"linspace"`):
59
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
60
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
61
- shift (`float`, defaults to 1.0):
62
- The shift value for the timestep schedule.
63
- reverse (`bool`, defaults to `True`):
64
- Whether to reverse the timestep schedule.
65
- """
66
-
67
- _compatibles = []
68
- order = 1
69
-
70
- @register_to_config
71
- def __init__(
72
- self,
73
- num_train_timesteps: int = 1000,
74
- shift: float = 1.0,
75
- reverse: bool = True,
76
- solver: str = "euler",
77
- n_tokens: Optional[int] = None,
78
- ):
79
- sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80
-
81
- if not reverse:
82
- sigmas = sigmas.flip(0)
83
-
84
- self.sigmas = sigmas
85
- # the value fed to model
86
- self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87
-
88
- self._step_index = None
89
- self._begin_index = None
90
-
91
- self.supported_solver = ["euler"]
92
- if solver not in self.supported_solver:
93
- raise ValueError(
94
- f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95
- )
96
-
97
- @property
98
- def step_index(self):
99
- """
100
- The index counter for current timestep. It will increase 1 after each scheduler step.
101
- """
102
- return self._step_index
103
-
104
- @property
105
- def begin_index(self):
106
- """
107
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
108
- """
109
- return self._begin_index
110
-
111
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112
- def set_begin_index(self, begin_index: int = 0):
113
- """
114
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
115
-
116
- Args:
117
- begin_index (`int`):
118
- The begin index for the scheduler.
119
- """
120
- self._begin_index = begin_index
121
-
122
- def _sigma_to_t(self, sigma):
123
- return sigma * self.config.num_train_timesteps
124
-
125
- def set_timesteps(
126
- self,
127
- num_inference_steps: int,
128
- device: Union[str, torch.device] = None,
129
- n_tokens: int = None,
130
- ):
131
- """
132
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133
-
134
- Args:
135
- num_inference_steps (`int`):
136
- The number of diffusion steps used when generating samples with a pre-trained model.
137
- device (`str` or `torch.device`, *optional*):
138
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
139
- n_tokens (`int`, *optional*):
140
- Number of tokens in the input sequence.
141
- """
142
- self.num_inference_steps = num_inference_steps
143
-
144
- sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145
- sigmas = self.sd3_time_shift(sigmas)
146
-
147
- if not self.config.reverse:
148
- sigmas = 1 - sigmas
149
-
150
- self.sigmas = sigmas
151
- self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152
- dtype=torch.float32, device=device
153
- )
154
-
155
- # Reset step index
156
- self._step_index = None
157
-
158
- def index_for_timestep(self, timestep, schedule_timesteps=None):
159
- if schedule_timesteps is None:
160
- schedule_timesteps = self.timesteps
161
-
162
- indices = (schedule_timesteps == timestep).nonzero()
163
-
164
- # The sigma index that is taken for the **very** first `step`
165
- # is always the second index (or the last index if there is only 1)
166
- # This way we can ensure we don't accidentally skip a sigma in
167
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168
- pos = 1 if len(indices) > 1 else 0
169
-
170
- return indices[pos].item()
171
-
172
- def _init_step_index(self, timestep):
173
- if self.begin_index is None:
174
- if isinstance(timestep, torch.Tensor):
175
- timestep = timestep.to(self.timesteps.device)
176
- self._step_index = self.index_for_timestep(timestep)
177
- else:
178
- self._step_index = self._begin_index
179
-
180
- def scale_model_input(
181
- self, sample: torch.Tensor, timestep: Optional[int] = None
182
- ) -> torch.Tensor:
183
- return sample
184
-
185
- def sd3_time_shift(self, t: torch.Tensor):
186
- return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187
-
188
- def step(
189
- self,
190
- model_output: torch.FloatTensor,
191
- timestep: Union[float, torch.FloatTensor],
192
- sample: torch.FloatTensor,
193
- return_dict: bool = True,
194
- ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195
- """
196
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
197
- process from the learned model outputs (most often the predicted noise).
198
-
199
- Args:
200
- model_output (`torch.FloatTensor`):
201
- The direct output from learned diffusion model.
202
- timestep (`float`):
203
- The current discrete timestep in the diffusion chain.
204
- sample (`torch.FloatTensor`):
205
- A current instance of a sample created by the diffusion process.
206
- generator (`torch.Generator`, *optional*):
207
- A random number generator.
208
- n_tokens (`int`, *optional*):
209
- Number of tokens in the input sequence.
210
- return_dict (`bool`):
211
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
212
- tuple.
213
-
214
- Returns:
215
- [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
216
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
217
- returned, otherwise a tuple is returned where the first element is the sample tensor.
218
- """
219
-
220
- if (
221
- isinstance(timestep, int)
222
- or isinstance(timestep, torch.IntTensor)
223
- or isinstance(timestep, torch.LongTensor)
224
- ):
225
- raise ValueError(
226
- (
227
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
229
- " one of the `scheduler.timesteps` as a timestep."
230
- ),
231
- )
232
-
233
- if self.step_index is None:
234
- self._init_step_index(timestep)
235
-
236
- # Upcast to avoid precision issues when computing prev_sample
237
- sample = sample.to(torch.float32)
238
-
239
- dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240
-
241
- if self.config.solver == "euler":
242
- prev_sample = sample + model_output.to(torch.float32) * dt
243
- else:
244
- raise ValueError(
245
- f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246
- )
247
-
248
- # upon completion increase step index by one
249
- self._step_index += 1
250
-
251
- if not return_dict:
252
- return (prev_sample,)
253
-
254
- return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255
-
256
- def __len__(self):
257
- return self.config.num_train_timesteps
 
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class FlowMatchDiscreteSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
+ denoising loop.
43
+ """
44
+
45
+ prev_sample: torch.FloatTensor
46
+
47
+
48
+ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
49
+ """
50
+ Euler scheduler.
51
+
52
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
+ methods the library implements for all schedulers such as loading and saving.
54
+
55
+ Args:
56
+ num_train_timesteps (`int`, defaults to 1000):
57
+ The number of diffusion steps to train the model.
58
+ timestep_spacing (`str`, defaults to `"linspace"`):
59
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
60
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
61
+ shift (`float`, defaults to 1.0):
62
+ The shift value for the timestep schedule.
63
+ reverse (`bool`, defaults to `True`):
64
+ Whether to reverse the timestep schedule.
65
+ """
66
+
67
+ _compatibles = []
68
+ order = 1
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ num_train_timesteps: int = 1000,
74
+ shift: float = 1.0,
75
+ reverse: bool = True,
76
+ solver: str = "euler",
77
+ n_tokens: Optional[int] = None,
78
+ ):
79
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
80
+
81
+ if not reverse:
82
+ sigmas = sigmas.flip(0)
83
+
84
+ self.sigmas = sigmas
85
+ # the value fed to model
86
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
87
+
88
+ self._step_index = None
89
+ self._begin_index = None
90
+
91
+ self.supported_solver = ["euler"]
92
+ if solver not in self.supported_solver:
93
+ raise ValueError(
94
+ f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
95
+ )
96
+
97
+ @property
98
+ def step_index(self):
99
+ """
100
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
101
+ """
102
+ return self._step_index
103
+
104
+ @property
105
+ def begin_index(self):
106
+ """
107
+ The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
108
+ """
109
+ return self._begin_index
110
+
111
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
112
+ def set_begin_index(self, begin_index: int = 0):
113
+ """
114
+ Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
115
+
116
+ Args:
117
+ begin_index (`int`):
118
+ The begin index for the scheduler.
119
+ """
120
+ self._begin_index = begin_index
121
+
122
+ def _sigma_to_t(self, sigma):
123
+ return sigma * self.config.num_train_timesteps
124
+
125
+ def set_timesteps(
126
+ self,
127
+ num_inference_steps: int,
128
+ device: Union[str, torch.device] = None,
129
+ n_tokens: int = None,
130
+ ):
131
+ """
132
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
133
+
134
+ Args:
135
+ num_inference_steps (`int`):
136
+ The number of diffusion steps used when generating samples with a pre-trained model.
137
+ device (`str` or `torch.device`, *optional*):
138
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
139
+ n_tokens (`int`, *optional*):
140
+ Number of tokens in the input sequence.
141
+ """
142
+ self.num_inference_steps = num_inference_steps
143
+
144
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
145
+ sigmas = self.sd3_time_shift(sigmas)
146
+
147
+ if not self.config.reverse:
148
+ sigmas = 1 - sigmas
149
+
150
+ self.sigmas = sigmas
151
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
152
+ dtype=torch.float32, device=device
153
+ )
154
+
155
+ # Reset step index
156
+ self._step_index = None
157
+
158
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
159
+ if schedule_timesteps is None:
160
+ schedule_timesteps = self.timesteps
161
+
162
+ indices = (schedule_timesteps == timestep).nonzero()
163
+
164
+ # The sigma index that is taken for the **very** first `step`
165
+ # is always the second index (or the last index if there is only 1)
166
+ # This way we can ensure we don't accidentally skip a sigma in
167
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
168
+ pos = 1 if len(indices) > 1 else 0
169
+
170
+ return indices[pos].item()
171
+
172
+ def _init_step_index(self, timestep):
173
+ if self.begin_index is None:
174
+ if isinstance(timestep, torch.Tensor):
175
+ timestep = timestep.to(self.timesteps.device)
176
+ self._step_index = self.index_for_timestep(timestep)
177
+ else:
178
+ self._step_index = self._begin_index
179
+
180
+ def scale_model_input(
181
+ self, sample: torch.Tensor, timestep: Optional[int] = None
182
+ ) -> torch.Tensor:
183
+ return sample
184
+
185
+ def sd3_time_shift(self, t: torch.Tensor):
186
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
187
+
188
+ def step(
189
+ self,
190
+ model_output: torch.FloatTensor,
191
+ timestep: Union[float, torch.FloatTensor],
192
+ sample: torch.FloatTensor,
193
+ return_dict: bool = True,
194
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
195
+ """
196
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
197
+ process from the learned model outputs (most often the predicted noise).
198
+
199
+ Args:
200
+ model_output (`torch.FloatTensor`):
201
+ The direct output from learned diffusion model.
202
+ timestep (`float`):
203
+ The current discrete timestep in the diffusion chain.
204
+ sample (`torch.FloatTensor`):
205
+ A current instance of a sample created by the diffusion process.
206
+ generator (`torch.Generator`, *optional*):
207
+ A random number generator.
208
+ n_tokens (`int`, *optional*):
209
+ Number of tokens in the input sequence.
210
+ return_dict (`bool`):
211
+ Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or
212
+ tuple.
213
+
214
+ Returns:
215
+ [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
216
+ If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
217
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
218
+ """
219
+
220
+ if (
221
+ isinstance(timestep, int)
222
+ or isinstance(timestep, torch.IntTensor)
223
+ or isinstance(timestep, torch.LongTensor)
224
+ ):
225
+ raise ValueError(
226
+ (
227
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
228
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
229
+ " one of the `scheduler.timesteps` as a timestep."
230
+ ),
231
+ )
232
+
233
+ if self.step_index is None:
234
+ self._init_step_index(timestep)
235
+
236
+ # Upcast to avoid precision issues when computing prev_sample
237
+ sample = sample.to(torch.float32)
238
+
239
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
240
+
241
+ if self.config.solver == "euler":
242
+ prev_sample = sample + model_output.to(torch.float32) * dt
243
+ else:
244
+ raise ValueError(
245
+ f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
246
+ )
247
+
248
+ # upon completion increase step index by one
249
+ self._step_index += 1
250
+
251
+ if not return_dict:
252
+ return (prev_sample,)
253
+
254
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
255
+
256
+ def __len__(self):
257
+ return self.config.num_train_timesteps