wondervictor commited on
Commit
6c9f9c3
·
verified ·
1 Parent(s): 6f18588

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -4
model.py CHANGED
@@ -74,8 +74,8 @@ class Model:
74
 
75
  def load_gpt(self, condition_type='canny'):
76
  gpt_ckpt = models[condition_type]
77
- # precision = torch.bfloat16
78
- precision = torch.float32
79
  latent_size = 768 // 16
80
  gpt_model = GPT_models["GPT-XL"](
81
  block_size=latent_size**2,
@@ -93,8 +93,8 @@ class Model:
93
  return gpt_model
94
 
95
  def load_t5(self):
96
- # precision = torch.bfloat16
97
- precision = torch.float32
98
  t5_model = T5Embedder(
99
  device=self.device,
100
  local_cache=True,
 
74
 
75
  def load_gpt(self, condition_type='canny'):
76
  gpt_ckpt = models[condition_type]
77
+ precision = torch.bfloat16
78
+ # precision = torch.float32
79
  latent_size = 768 // 16
80
  gpt_model = GPT_models["GPT-XL"](
81
  block_size=latent_size**2,
 
93
  return gpt_model
94
 
95
  def load_t5(self):
96
+ precision = torch.bfloat16
97
+ # precision = torch.float32
98
  t5_model = T5Embedder(
99
  device=self.device,
100
  local_cache=True,