ford442 committed
Commit c47663a · verified · 1 Parent(s): 3f9d242

Update app.py

Files changed (1)
  1. app.py +4 -18
app.py CHANGED
@@ -92,15 +92,8 @@ def scheduler_swap_callback(pipeline, step_index, timestep, callback_kwargs):
     torch.backends.cuda.preferred_blas_library="cublaslt"
     if step_index == int(pipeline.num_timesteps * 0.5):
         # torch.set_float32_matmul_precision("medium")
-        callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.float64)
-        def change_dtype(module):
-            for child in module.children():
-                if len(list(child.children())) > 0:
-                    change_dtype(child)
-                for param in child.parameters():
-                    param.data = param.data.to(torch.float64)
-
-        change_dtype(pipeline.unet)
+        #callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.float64)
+        #pipe.unet.to(torch.float64)
         # pipe.guidance_scale=1.0
         # pipe.scheduler.set_timesteps(num_inference_steps*.70)
         # print(f"-- setting step {pipeline.num_timesteps * 0.1} --")
@@ -110,15 +103,8 @@ def scheduler_swap_callback(pipeline, step_index, timestep, callback_kwargs):
         torch.backends.cudnn.allow_tf32 = False
         torch.backends.cuda.matmul.allow_tf32 = False
         torch.set_float32_matmul_precision("highest")
-        callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.bfloat16)
-        def change_dtype(module):
-            for child in module.children():
-                if len(list(child.children())) > 0:
-                    change_dtype(child)
-                for param in child.parameters():
-                    param.data = param.data.to(torch.bfloat16)
-
-        change_dtype(pipeline.unet)
+        #callback_kwargs["latents"] = callback_kwargs["latents"].to(torch.bfloat16)
+        #pipe.unet.to(torch.float64)
         # pipe.vae = vae_a
         # pipe.unet = unet_a
         # torch.backends.cudnn.deterministic = False
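
For context on what this change simplifies: the deleted change_dtype helper recursively walked the UNet's submodules and cast each parameter's data in place, which is the same effect torch.nn.Module.to(dtype) already has on an entire module tree; the new commented-out pipe.unet.to(...) lines rely on exactly that built-in behavior. Below is a minimal sketch (not part of the commit; the nn.Sequential stand-in and the dtype argument are illustrative assumptions) showing the two approaches producing the same result on a small nested module:

import torch
import torch.nn as nn

def change_dtype(module, dtype):
    # The removed helper, parameterized by dtype for illustration:
    # recurse into nested children and cast each parameter's data in place.
    # Note child.parameters() already yields grandchild parameters too, so
    # deeper parameters get cast redundantly; harmless, but extra work.
    for child in module.children():
        if len(list(child.children())) > 0:
            change_dtype(child, dtype)
        for param in child.parameters():
            param.data = param.data.to(dtype)

# Stand-ins for the UNet; any nested module works for the comparison.
net_a = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
net_b = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))

change_dtype(net_a, torch.float64)   # manual recursion, as in the old code
net_b.to(torch.float64)              # built-in one-liner, as in the new code

print(all(p.dtype == torch.float64 for p in net_a.parameters()))  # True
print(all(p.dtype == torch.float64 for p in net_b.parameters()))  # True

One behavioral difference worth noting: Module.to(dtype) also casts floating-point buffers (e.g. running statistics), while the recursive helper touched only parameters, so the one-liner is both shorter and more thorough.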