jokerbit commited on
Commit
194a53d
·
verified ·
1 Parent(s): 802934b

more compile

Browse files
Files changed (1) hide show
  1. src/pipeline.py +3 -3
src/pipeline.py CHANGED
@@ -52,9 +52,9 @@ def load_pipeline() -> Pipeline:
52
  pipeline.to(memory_format=torch.channels_last)
53
  quantize_(pipeline.vae, int8_weight_only())
54
  pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
55
-
56
- for _ in range(2):
57
- pipeline("cat", num_inference_steps=4)
58
 
59
  return pipeline
60
 
 
52
  pipeline.to(memory_format=torch.channels_last)
53
  quantize_(pipeline.vae, int8_weight_only())
54
  pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune", fullgraph=True)
55
+ with torch.inference_mode():
56
+ for _ in range(2):
57
+ pipeline("cat", num_inference_steps=4)
58
 
59
  return pipeline
60