ysharma HF staff commited on
Commit
eed5e87
·
verified ·
1 Parent(s): cc59016

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -0
app.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import uuid
3
+ import random
4
+ import numpy as np
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from PIL import Image
9
+ from diffusers import FluxInpaintPipeline
10
+
11
+ from gradio_client import Client, handle_file
12
+ from PIL import Image
13
+
14
+ # Set an environment variable
15
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
+
17
+ MARKDOWN = """
18
+ # FLUX.1 Inpainting with Text guided Mask🔥
19
+ Shoutout to [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
20
+ creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
21
+ for taking it to the next level by enabling inpainting with the FLUX.
22
+ """
23
+
24
+ MAX_SEED = np.iinfo(np.int32).max
25
+ MAX_IMAGE_SIZE = 2048
26
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ # Using Gradio Python Client to query EVF-SAM demo, hosted on SPaces, as an endpoint
29
+ client = Client("ysharma/evf-sam", hf_token=HF_TOKEN)
30
+
31
+
32
+ pipe = FluxInpaintPipeline.from_pretrained(
33
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
34
+
35
+
36
+ def resize_image_dimensions(
37
+ original_resolution_wh: Tuple[int, int],
38
+ maximum_dimension: int = 2048
39
+ ) -> Tuple[int, int]:
40
+ width, height = original_resolution_wh
41
+
42
+ if width <= maximum_dimension and height <= maximum_dimension:
43
+ width = width - (width % 32)
44
+ height = height - (height % 32)
45
+ return width, height
46
+
47
+ if width > height:
48
+ scaling_factor = maximum_dimension / width
49
+ else:
50
+ scaling_factor = maximum_dimension / height
51
+
52
+ new_width = int(width * scaling_factor)
53
+ new_height = int(height * scaling_factor)
54
+
55
+ new_width = new_width - (new_width % 32)
56
+ new_height = new_height - (new_height % 32)
57
+
58
+ return new_width, new_height
59
+
60
+
61
+ def evf_sam_mask(image, prompt):
62
+ print(type(image))
63
+ filename=str(uuid.uuid4()) + ".jpg"
64
+ image.save(filename)
65
+ images = client.predict(
66
+ image_np=handle_file(filename),
67
+ prompt=prompt,
68
+ api_name="/predict")
69
+ print(images)
70
+ # Open the image
71
+ webp_image = Image.open(images[1])
72
+ # Convert to RGB mode if it's not already
73
+ if webp_image.mode != 'RGB':
74
+ webp_image = webp_image.convert('RGB')
75
+ # Create a new PIL Image object
76
+ pil_image = Image.new('RGB', webp_image.size)
77
+ pil_image.paste(webp_image)
78
+
79
+ print(pil_image)
80
+ print(type(pil_image))
81
+
82
+ return pil_image
83
+
84
+ @spaces.GPU(duration=150)
85
+ def process(
86
+ input_image_editor: dict,
87
+ input_text: str,
88
+ inpaint_text: str,
89
+ seed_slicer: int,
90
+ randomize_seed_checkbox: bool,
91
+ strength_slider: float,
92
+ num_inference_steps_slider: int,
93
+ progress=gr.Progress(track_tqdm=True)
94
+ ):
95
+ if not input_text:
96
+ gr.Info("Please enter a text prompt.")
97
+ return None
98
+
99
+ image = input_image_editor['background']
100
+ #mask = input_image_editor['layers'][0]
101
+ print(f"type of image: {type(image)}")
102
+ mask = evf_sam_mask(image, input_text)
103
+ print(f"type of mask: {type(mask)}")
104
+ print(f"inpaint_text: {inpaint_text}")
105
+ print(f"input_text: {input_text}")
106
+
107
+ if not image:
108
+ gr.Info("Please upload an image.")
109
+ return None
110
+
111
+ if not mask:
112
+ gr.Info("Please draw a mask on the image.")
113
+ return None
114
+
115
+ width, height = resize_image_dimensions(original_resolution_wh=image.size)
116
+ resized_image = image.resize((width, height), Image.LANCZOS)
117
+ resized_mask = mask.resize((width, height), Image.NEAREST)
118
+
119
+ if randomize_seed_checkbox:
120
+ seed_slicer = random.randint(0, MAX_SEED)
121
+ generator = torch.Generator().manual_seed(seed_slicer)
122
+ result = pipe(
123
+ prompt=inpaint_text,
124
+ image=resized_image,
125
+ mask_image=resized_mask,
126
+ width=width,
127
+ height=height,
128
+ strength=strength_slider,
129
+ generator=generator,
130
+ num_inference_steps=num_inference_steps_slider
131
+ ).images[0]
132
+ print('INFERENCE DONE')
133
+ return result, resized_mask
134
+
135
+
136
+ with gr.Blocks() as demo:
137
+ gr.Markdown(MARKDOWN)
138
+ with gr.Row():
139
+ with gr.Column():
140
+ input_image_editor_component = gr.ImageEditor(
141
+ label='Image',
142
+ type='pil',
143
+ sources=["upload", "webcam"],
144
+ image_mode='RGB',
145
+ layers=False,
146
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
147
+
148
+ with gr.Row():
149
+ with gr.Column():
150
+ input_text_component = gr.Text(
151
+ label="Segment",
152
+ show_label=False,
153
+ max_lines=1,
154
+ placeholder="segmentation text",
155
+ container=False,
156
+ )
157
+ inpaint_text_component = gr.Text(
158
+ label="Inpaint",
159
+ show_label=False,
160
+ max_lines=1,
161
+ placeholder="Inpaint text",
162
+ container=False,
163
+ )
164
+ submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
165
+
166
+ with gr.Accordion("Advanced Settings", open=False):
167
+ seed_slicer_component = gr.Slider(
168
+ label="Seed",
169
+ minimum=0,
170
+ maximum=MAX_SEED,
171
+ step=1,
172
+ value=42,
173
+ )
174
+
175
+ randomize_seed_checkbox_component = gr.Checkbox(
176
+ label="Randomize seed", value=False)
177
+
178
+ with gr.Row():
179
+ strength_slider_component = gr.Slider(
180
+ label="Strength",
181
+ minimum=0,
182
+ maximum=1,
183
+ step=0.01,
184
+ value=0.75,
185
+ )
186
+
187
+ num_inference_steps_slider_component = gr.Slider(
188
+ label="Number of inference steps",
189
+ minimum=1,
190
+ maximum=50,
191
+ step=1,
192
+ value=20,
193
+ )
194
+ with gr.Column():
195
+ output_image_component = gr.Image(
196
+ type='pil', image_mode='RGB', label='Generated image')
197
+ with gr.Accordion("Generated Mask", open=False):
198
+ output_mask_component = gr.Image(
199
+ type='pil', image_mode='RGB', label='Input mask')
200
+
201
+ submit_button_component.click(
202
+ fn=process,
203
+ inputs=[
204
+ input_image_editor_component,
205
+ input_text_component,
206
+ inpaint_text_component,
207
+ seed_slicer_component,
208
+ randomize_seed_checkbox_component,
209
+ strength_slider_component,
210
+ num_inference_steps_slider_component
211
+ ],
212
+ outputs=[
213
+ output_image_component,
214
+ output_mask_component,
215
+ ]
216
+ )
217
+
218
+ demo.launch(debug=True)
219
+