daniel shalem commited on
Commit
1940326
·
1 Parent(s): 645fba0

Feature: Add mixed precision support and direct bfloat16 support.

Browse files
xora/examples/image_to_video.py CHANGED
@@ -136,6 +136,12 @@ def main():
136
  "--frame_rate", type=int, default=25, help="Frame rate for the output video"
137
  )
138
 
 
 
 
 
 
 
139
  # Prompts
140
  parser.add_argument(
141
  "--prompt",
@@ -224,6 +230,7 @@ def main():
224
  is_video=True,
225
  vae_per_channel_normalize=True,
226
  conditioning_method=ConditioningMethod.FIRST_FRAME,
 
227
  ).images
228
 
229
  # Save output video
 
136
  "--frame_rate", type=int, default=25, help="Frame rate for the output video"
137
  )
138
 
139
+ parser.add_argument(
140
+ "--mixed_precision",
141
+ action="store_true",
142
+ help="Mixed precision in float32 and bfloat16",
143
+ )
144
+
145
  # Prompts
146
  parser.add_argument(
147
  "--prompt",
 
230
  is_video=True,
231
  vae_per_channel_normalize=True,
232
  conditioning_method=ConditioningMethod.FIRST_FRAME,
233
+ mixed_precision=args.mixed_precision,
234
  ).images
235
 
236
  # Save output video
xora/models/transformers/transformer3d.py CHANGED
@@ -305,7 +305,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
305
  sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
306
  cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
307
  sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
308
- return cos_freq, sin_freq
309
 
310
  def forward(
311
  self,
 
305
  sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
306
  cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
307
  sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
308
+ return cos_freq.to(dtype), sin_freq.to(dtype)
309
 
310
  def forward(
311
  self,
xora/pipelines/pipeline_xora_video.py CHANGED
@@ -9,6 +9,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
9
 
10
  import torch
11
  import torch.nn.functional as F
 
12
  from diffusers.image_processor import VaeImageProcessor
13
  from diffusers.models import AutoencoderKL
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -758,6 +759,7 @@ class XoraVideoPipeline(DiffusionPipeline):
758
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
759
  clean_caption: bool = True,
760
  media_items: Optional[torch.FloatTensor] = None,
 
761
  **kwargs,
762
  ) -> Union[ImagePipelineOutput, Tuple]:
763
  """
@@ -1006,16 +1008,22 @@ class XoraVideoPipeline(DiffusionPipeline):
1006
 
1007
  if conditioning_mask is not None:
1008
  current_timestep = current_timestep * (1 - conditioning_mask)
 
 
 
 
 
1009
 
1010
  # predict noise model_output
1011
- noise_pred = self.transformer(
1012
- latent_model_input.to(self.transformer.dtype),
1013
- indices_grid,
1014
- encoder_hidden_states=prompt_embeds.to(self.transformer.dtype),
1015
- encoder_attention_mask=prompt_attention_mask,
1016
- timestep=current_timestep,
1017
- return_dict=False,
1018
- )[0]
 
1019
 
1020
  # perform guidance
1021
  if do_classifier_free_guidance:
 
9
 
10
  import torch
11
  import torch.nn.functional as F
12
+ from contextlib import nullcontext
13
  from diffusers.image_processor import VaeImageProcessor
14
  from diffusers.models import AutoencoderKL
15
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
759
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
760
  clean_caption: bool = True,
761
  media_items: Optional[torch.FloatTensor] = None,
762
+ mixed_precision: bool = False,
763
  **kwargs,
764
  ) -> Union[ImagePipelineOutput, Tuple]:
765
  """
 
1008
 
1009
  if conditioning_mask is not None:
1010
  current_timestep = current_timestep * (1 - conditioning_mask)
1011
+ # Choose the appropriate context manager based on `mixed_precision`
1012
+ if mixed_precision:
1013
+ context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
1014
+ else:
1015
+ context_manager = nullcontext() # Dummy context manager
1016
 
1017
  # predict noise model_output
1018
+ with context_manager:
1019
+ noise_pred = self.transformer(
1020
+ latent_model_input.to(self.transformer.dtype),
1021
+ indices_grid,
1022
+ encoder_hidden_states=prompt_embeds.to(self.transformer.dtype),
1023
+ encoder_attention_mask=prompt_attention_mask,
1024
+ timestep=current_timestep,
1025
+ return_dict=False,
1026
+ )[0]
1027
 
1028
  # perform guidance
1029
  if do_classifier_free_guidance: