Yuanshi commited on
Commit
fb17308
·
1 Parent(s): 35e4ce9

1024 support

Browse files
Files changed (2) hide show
  1. app.py +65 -31
  2. src/generate.py +10 -2
app.py CHANGED
@@ -8,11 +8,7 @@ import numpy as np
8
 
9
  from src.generate import seed_everything, generate
10
 
11
- # pipe = None
12
-
13
-
14
- # def init_pipeline():
15
- # global pipe
16
  pipe = FluxPipeline.from_pretrained(
17
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
18
  )
@@ -20,12 +16,17 @@ pipe = pipe.to("cuda")
20
  pipe.load_lora_weights(
21
  "Yuanshi/OminiControl",
22
  weight_name=f"omini/subject_512.safetensors",
23
- adapter_name="subject",
 
 
 
 
 
24
  )
25
 
 
26
  @spaces.GPU
27
- def process_image_and_text(image, text):
28
- # center crop image
29
  w, h, min_size = image.size[0], image.size[1], min(image.size)
30
  image = image.crop(
31
  (
@@ -39,16 +40,13 @@ def process_image_and_text(image, text):
39
 
40
  condition = Condition("subject", image)
41
 
42
- # if pipe is None:
43
- # init_pipeline()
44
-
45
  result_img = generate(
46
  pipe,
47
  prompt=text.strip(),
48
  conditions=[condition],
49
  num_inference_steps=8,
50
- height=512,
51
- width=512,
52
  ).images[0]
53
 
54
  return result_img
@@ -58,38 +56,74 @@ def get_samples():
58
  sample_list = [
59
  {
60
  "image": "assets/oranges.jpg",
 
61
  "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
62
  },
63
  {
64
  "image": "assets/penguin.jpg",
 
65
  "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
66
  },
67
  {
68
  "image": "assets/rc_car.jpg",
 
69
  "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
70
  },
71
  {
72
  "image": "assets/clock.jpg",
 
73
  "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
74
  },
75
  ]
76
- return [[Image.open(sample["image"]), sample["text"]] for sample in sample_list]
77
-
78
-
79
- demo = gr.Interface(
80
- fn=process_image_and_text,
81
- inputs=[
82
- gr.Image(type="pil"),
83
- gr.Textbox(lines=2),
84
- ],
85
- outputs=gr.Image(type="pil"),
86
- title="OminiControl / Subject driven generation",
87
- examples=get_samples(),
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  if __name__ == "__main__":
91
- # init_pipeline()
92
- demo.launch(
93
- debug=True,
94
- ssr_mode=False
95
- )
 
8
 
9
  from src.generate import seed_everything, generate
10
 
11
+ pipe = None
 
 
 
 
12
  pipe = FluxPipeline.from_pretrained(
13
  "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
14
  )
 
16
  pipe.load_lora_weights(
17
  "Yuanshi/OminiControl",
18
  weight_name=f"omini/subject_512.safetensors",
19
+ adapter_name="subject_512",
20
+ )
21
+ pipe.load_lora_weights(
22
+ "Yuanshi/OminiControl",
23
+ weight_name=f"omini/subject_1024_beta.safetensors",
24
+ adapter_name="subject_1024",
25
  )
26
 
27
+
28
  @spaces.GPU
29
+ def process_image_and_text(image, resolution, text):
 
30
  w, h, min_size = image.size[0], image.size[1], min(image.size)
31
  image = image.crop(
32
  (
 
40
 
41
  condition = Condition("subject", image)
42
 
 
 
 
43
  result_img = generate(
44
  pipe,
45
  prompt=text.strip(),
46
  conditions=[condition],
47
  num_inference_steps=8,
48
+ height=resolution,
49
+ width=resolution,
50
  ).images[0]
51
 
52
  return result_img
 
56
  sample_list = [
57
  {
58
  "image": "assets/oranges.jpg",
59
+ "resolution": 512,
60
  "text": "A very close up view of this item. It is placed on a wooden table. The background is a dark room, the TV is on, and the screen is showing a cooking show. With text on the screen that reads 'Omini Control!'",
61
  },
62
  {
63
  "image": "assets/penguin.jpg",
64
+ "resolution": 512,
65
  "text": "On Christmas evening, on a crowded sidewalk, this item sits on the road, covered in snow and wearing a Christmas hat, holding a sign that reads 'Omini Control!'",
66
  },
67
  {
68
  "image": "assets/rc_car.jpg",
69
+ "resolution": 1024,
70
  "text": "A film style shot. On the moon, this item drives across the moon surface. The background is that Earth looms large in the foreground.",
71
  },
72
  {
73
  "image": "assets/clock.jpg",
74
+ "resolution": 1024,
75
  "text": "In a Bauhaus style room, this item is placed on a shiny glass table, with a vase of flowers next to it. In the afternoon sun, the shadows of the blinds are cast on the wall.",
76
  },
77
  ]
78
+ return [
79
+ [
80
+ Image.open(sample["image"]).resize((512, 512)),
81
+ sample["resolution"],
82
+ sample["text"],
83
+ ]
84
+ for sample in sample_list
85
+ ]
86
+
87
+
88
+ header = """
89
+ # 🌍 OminiControl / FLUX
90
+
91
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
92
+ <a href="https://arxiv.org/abs/2411.15098"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
93
+ <a href="https://huggingface.co/Yuanshi/OminiControl"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
94
+ <a href="https://github.com/Yuanshi9815/OminiControl"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
95
+ </div>
96
+ """
97
+
98
+
99
+ def create_app():
100
+ with gr.Blocks() as app:
101
+ gr.Markdown(header)
102
+ with gr.Tabs():
103
+ with gr.Tab("Subject-driven"):
104
+ gr.Interface(
105
+ fn=process_image_and_text,
106
+ inputs=[
107
+ gr.Image(type="pil", label="Condition Image", width=300),
108
+ gr.Radio(
109
+ [("512", 512), ("1024(beta)", 1024)],
110
+ label="Resolution",
111
+ value=512,
112
+ ),
113
+ # gr.Slider(4, 16, 4, step=4, label="Inference Steps"),
114
+ gr.Textbox(lines=2, label="Text Prompt"),
115
+ ],
116
+ outputs=gr.Image(type="pil"),
117
+ examples=get_samples(),
118
+ )
119
+ with gr.Tab("Fill"):
120
+ gr.Markdown("Coming soon")
121
+ with gr.Tab("Canny"):
122
+ gr.Markdown("Coming soon")
123
+ with gr.Tab("Depth"):
124
+ gr.Markdown("Coming soon")
125
+ return app
126
+
127
 
128
  if __name__ == "__main__":
129
+ create_app().launch(debug=True, ssr_mode=False)
 
 
 
 
src/generate.py CHANGED
@@ -166,7 +166,12 @@ def generate(
166
  use_condition = conditions is not None or []
167
  if use_condition:
168
  assert len(conditions) <= 1, "Only one condition is supported for now."
169
- pipeline.set_adapters(conditions[0].condition_type)
 
 
 
 
 
170
  for condition in conditions:
171
  tokens, ids, type_id = condition.encode(self)
172
  condition_latents.append(tokens) # [batch_size, token_n, token_dim]
@@ -175,7 +180,10 @@ def generate(
175
  condition_latents = torch.cat(condition_latents, dim=1)
176
  condition_ids = torch.cat(condition_ids, dim=0)
177
  if condition.condition_type == "subject":
178
- condition_ids[:, 2] += width // 16
 
 
 
179
  condition_type_ids = torch.cat(condition_type_ids, dim=0)
180
 
181
  # 5. Prepare timesteps
 
166
  use_condition = conditions is not None or []
167
  if use_condition:
168
  assert len(conditions) <= 1, "Only one condition is supported for now."
169
+ pipeline.set_adapters(
170
+ {
171
+ 512: "subject_512",
172
+ 1024: "subject_1024",
173
+ }[height]
174
+ )
175
  for condition in conditions:
176
  tokens, ids, type_id = condition.encode(self)
177
  condition_latents.append(tokens) # [batch_size, token_n, token_dim]
 
180
  condition_latents = torch.cat(condition_latents, dim=1)
181
  condition_ids = torch.cat(condition_ids, dim=0)
182
  if condition.condition_type == "subject":
183
+ delta = 32 if height == 512 else -32
184
+ # print(f"Condition delta: {delta}")
185
+ condition_ids[:, 2] += delta
186
+
187
  condition_type_ids = torch.cat(condition_type_ids, dim=0)
188
 
189
  # 5. Prepare timesteps