Yuanshi commited on
Commit
35e4ce9
·
verified ·
1 Parent(s): 9ce9f52
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image, ImageDraw, ImageFont
4
  from src.condition import Condition
5
  from diffusers.pipelines import FluxPipeline
@@ -7,22 +8,22 @@ import numpy as np
7
 
8
  from src.generate import seed_everything, generate
9
 
10
- pipe = None
11
 
12
 
13
- def init_pipeline():
14
- global pipe
15
- pipe = FluxPipeline.from_pretrained(
16
- "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
17
- )
18
- pipe = pipe.to("cuda")
19
- pipe.load_lora_weights(
20
- "Yuanshi/OminiControl",
21
- weight_name=f"omini/subject_512.safetensors",
22
- adapter_name="subject",
23
- )
24
-
25
 
 
26
  def process_image_and_text(image, text):
27
  # center crop image
28
  w, h, min_size = image.size[0], image.size[1], min(image.size)
@@ -38,8 +39,8 @@ def process_image_and_text(image, text):
38
 
39
  condition = Condition("subject", image)
40
 
41
- if pipe is None:
42
- init_pipeline()
43
 
44
  result_img = generate(
45
  pipe,
@@ -87,7 +88,7 @@ demo = gr.Interface(
87
  )
88
 
89
  if __name__ == "__main__":
90
- init_pipeline()
91
  demo.launch(
92
  debug=True,
93
  ssr_mode=False
 
1
  import gradio as gr
2
  import torch
3
+ import spaces
4
  from PIL import Image, ImageDraw, ImageFont
5
  from src.condition import Condition
6
  from diffusers.pipelines import FluxPipeline
 
8
 
9
  from src.generate import seed_everything, generate
10
 
11
+ # pipe = None
12
 
13
 
14
+ # def init_pipeline():
15
+ # global pipe
16
+ pipe = FluxPipeline.from_pretrained(
17
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
18
+ )
19
+ pipe = pipe.to("cuda")
20
+ pipe.load_lora_weights(
21
+ "Yuanshi/OminiControl",
22
+ weight_name=f"omini/subject_512.safetensors",
23
+ adapter_name="subject",
24
+ )
 
25
 
26
+ @spaces.GPU
27
  def process_image_and_text(image, text):
28
  # center crop image
29
  w, h, min_size = image.size[0], image.size[1], min(image.size)
 
39
 
40
  condition = Condition("subject", image)
41
 
42
+ # if pipe is None:
43
+ # init_pipeline()
44
 
45
  result_img = generate(
46
  pipe,
 
88
  )
89
 
90
  if __name__ == "__main__":
91
+ # init_pipeline()
92
  demo.launch(
93
  debug=True,
94
  ssr_mode=False