gaur3009 commited on
Commit
d8988f8
·
verified ·
1 Parent(s): 0e15ea7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -106
app.py CHANGED
@@ -1,124 +1,96 @@
1
- from __future__ import annotations
2
- import math
3
- import random
4
  import gradio as gr
 
 
5
  import numpy as np
 
6
  import torch
7
- from PIL import Image
8
- from diffusers import StableDiffusionXLImg2ImgPipeline, EDMEulerScheduler, AutoencoderKL
9
- from huggingface_hub import hf_hub_download
10
 
11
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
12
 
13
- pipe_edit = StableDiffusionXLImg2ImgPipeline.from_single_file(
14
- hf_hub_download(repo_id="stabilityai/cosxl", filename="cosxl_edit.safetensors"),
15
- num_in_channels=8,
16
- is_cosxl_edit=True,
17
- vae=vae,
18
- torch_dtype=torch.float16,
19
  )
20
-
21
- pipe_edit.scheduler = EDMEulerScheduler(sigma_min=0.002, sigma_max=120.0, sigma_data=1.0, prediction_type="v_prediction")
22
- pipe_edit.to("cuda")
23
-
24
- refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
25
- "stabilityai/stable-diffusion-xl-refiner-1.0",
26
- vae=vae,
27
- torch_dtype=torch.float16,
28
- use_safetensors=True,
29
- variant="fp16"
30
  )
31
- refiner.to("cuda")
32
 
33
- def set_timesteps_patched(self, num_inference_steps: int, device=None):
34
- self.num_inference_steps = num_inference_steps
35
- ramp = np.linspace(0, 1, self.num_inference_steps)
36
- sigmas = torch.linspace(math.log(self.config.sigma_min), math.log(self.config.sigma_max), len(ramp)).exp().flip(0)
37
- sigmas = sigmas.to(dtype=torch.float32, device=device)
38
- self.timesteps = self.precondition_noise(sigmas)
39
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
40
- self._step_index = None
41
- self._begin_index = None
42
- self.sigmas = self.sigmas.to("cpu")
 
 
43
 
44
- EDMEulerScheduler.set_timesteps = set_timesteps_patched
 
 
 
 
45
 
46
- def king(input_image, instruction: str, negative_prompt: str = "", steps: int = 25, randomize_seed: bool = True, seed: int = 2404, guidance_scale: float = 6, progress=gr.Progress(track_tqdm=True)):
47
- input_image = Image.open(input_image).convert('RGB')
48
- if randomize_seed:
49
- seed = random.randint(0, 999999)
50
- generator = torch.manual_seed(seed)
51
- output_image = pipe_edit(
52
- instruction,
53
- negative_prompt=negative_prompt,
54
- image=input_image,
55
- guidance_scale=guidance_scale,
56
- image_guidance_scale=1.5,
57
- width=input_image.width,
58
- height=input_image.height,
59
- num_inference_steps=steps,
60
- generator=generator,
61
- output_type="latent",
62
- ).images
63
- refine = refiner(
64
- prompt=f"{instruction}, 4k, hd, high quality, masterpiece",
65
- negative_prompt=negative_prompt,
66
- guidance_scale=7.5,
67
- num_inference_steps=steps,
68
- image=output_image,
69
- generator=generator,
70
- ).images[0]
71
- return seed, refine
72
 
73
- css = '''
74
- .gradio-container{max-width: 700px !important}
75
- h1{text-align:center}
76
- footer {
77
- visibility: hidden
78
- }
79
- '''
80
 
81
- examples = [
82
- ["./supercar.png", "make it red"],
83
- ["./red_car.png", "add some snow"],
84
- ]
85
 
86
- with gr.Blocks(css=css) as demo:
87
- gr.Markdown("# Image Editing\n### Note: First image generation takes time")
88
- with gr.Row():
89
- instruction = gr.Textbox(lines=1, label="Instruction", interactive=True)
90
- generate_button = gr.Button("Run", scale=0)
91
-
92
- with gr.Row():
93
- input_image = gr.Image(label="Image", type='filepath', interactive=True)
94
 
95
- with gr.Row():
96
- guidance_scale = gr.Number(value=6.0, step=0.1, label="Guidance Scale", interactive=True)
97
- steps = gr.Number(value=25, step=1, label="Steps", interactive=True)
 
 
 
98
 
99
- with gr.Accordion("Advanced options", open=False):
100
- with gr.Row():
101
- negative_prompt = gr.Text(
102
- label="Negative prompt",
103
- max_lines=1,
104
- value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, ugly, disgusting, blurry, amputation,(face asymmetry, eyes asymmetry, deformed eyes, open mouth)",
105
- visible=True
106
- )
107
- with gr.Row():
108
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True, interactive=True)
109
- seed = gr.Number(value=2404, step=1, label="Seed", interactive=True)
110
 
111
- gr.Examples(
112
- examples=examples,
113
- inputs=[input_image, instruction],
114
- outputs=[input_image],
115
- cache_examples=False,
116
- )
117
 
118
- generate_button.click(
119
- king,
120
- inputs=[input_image, instruction, negative_prompt, steps, randomize_seed, seed, guidance_scale],
121
- outputs=[seed, input_image],
122
- )
123
 
124
- demo.queue(max_size=500).launch()
 
 
 
 
 
1
  import gradio as gr
2
+ from gradio_imageslider import ImageSlider
3
+ from PIL import Image, ImageDraw, ImageFont
4
  import numpy as np
5
+ import cv2
6
  import torch
7
+ from torchvision import transforms
8
+ from transformers import AutoModelForImageSegmentation
 
9
 
10
+ torch.set_float32_matmul_precision(["high", "highest"][0])
11
 
12
+ # Load BiRefNet model for background removal
13
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
14
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
 
 
 
15
  )
16
+ birefnet.to("cuda")
17
+ transform_image = transforms.Compose(
18
+ [
19
+ transforms.Resize((1024, 1024)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
22
+ ]
 
 
 
23
  )
 
24
 
25
+ def load_img(image, output_type="numpy"):
26
+ if output_type == "pil":
27
+ return Image.open(image).convert("RGB")
28
+ else:
29
+ return np.array(Image.open(image).convert("RGB"))
30
+
31
+ def add_text_to_image(image, text, position, color, font_size):
32
+ img = Image.fromarray(image)
33
+ draw = ImageDraw.Draw(img)
34
+ font = ImageFont.truetype("arial.ttf", font_size)
35
+ draw.text(position, text, fill=color, font=font)
36
+ return np.array(img)
37
 
38
+ def inpaint_image(image, mask, inpaint_radius):
39
+ img_cv = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
40
+ mask_cv = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
41
+ result = cv2.inpaint(img_cv, mask_cv, inpaint_radius, cv2.INPAINT_TELEA)
42
+ return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
43
 
44
+ def background_removal(image):
45
+ im = load_img(image, output_type="pil")
46
+ im = im.convert("RGB")
47
+ image_size = im.size
48
+ origin = im.copy()
49
+ image = load_img(im)
50
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
51
+ with torch.no_grad():
52
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
53
+ pred = preds[0].squeeze()
54
+ pred_pil = transforms.ToPILImage()(pred)
55
+ mask = pred_pil.resize(image_size)
56
+ im.putalpha(mask)
57
+ return (im, origin)
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ def update_image(image, text, color, font_size, mask_image, inpaint_radius):
60
+ img_with_text = add_text_to_image(image, text, (50, 50), color, font_size)
61
+ if mask_image is not None:
62
+ mask = np.array(mask_image)
63
+ img_with_text = inpaint_image(img_with_text, mask, inpaint_radius)
64
+ return img_with_text
 
65
 
66
+ def fn(image):
67
+ return background_removal(image)
 
 
68
 
69
+ slider1 = ImageSlider(label="Original Image", type="pil")
70
+ slider2 = ImageSlider(label="Processed Image", type="pil")
 
 
 
 
 
 
71
 
72
+ image_input = gr.Image(label="Upload an image for background removal")
73
+ text_input = gr.Textbox(label="Enter Text to Add", placeholder="Your text here...")
74
+ color_input = gr.ColorPicker(label="Text Color")
75
+ font_size_input = gr.Slider(minimum=10, maximum=100, label="Font Size")
76
+ mask_input = gr.Image(type="numpy", label="Upload Mask Image (for Inpainting)", optional=True)
77
+ inpaint_radius_input = gr.Slider(minimum=1, maximum=50, value=3, label="Inpaint Radius")
78
 
79
+ bg_removal_interface = gr.Interface(
80
+ fn, inputs=image_input, outputs=slider1, examples=["chameleon.jpg"]
81
+ )
 
 
 
 
 
 
 
 
82
 
83
+ design_editing_interface = gr.Interface(
84
+ fn=lambda image, text, color, font_size, mask_image, inpaint_radius: update_image(image, text, color, font_size, mask_image, inpaint_radius),
85
+ inputs=[image_input, text_input, color_input, font_size_input, mask_input, inpaint_radius_input],
86
+ outputs=slider2
87
+ )
 
88
 
89
+ demo = gr.TabbedInterface(
90
+ [bg_removal_interface, design_editing_interface],
91
+ ["Background Removal", "Design Editing"],
92
+ title="Advanced Image Editor"
93
+ )
94
 
95
+ if __name__ == "__main__":
96
+ demo.launch()