import subprocess
import shlex
import sys

# Install the bundled MagicQuill Gradio component before importing it below.
# pip is invoked through the running interpreter so the wheel is installed into
# the same environment this script runs in (a bare `pip` on PATH may belong to
# a different Python installation).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "./gradio_magicquill-0.0.1-py3-none-any.whl"]
)

import gradio as gr
from gradio_magicquill import MagicQuill
import random
import torch
import numpy as np
from PIL import Image, ImageOps
import base64
import io
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
import uvicorn
from MagicQuill import folder_paths
from MagicQuill.scribble_color_edit import ScribbleColorEditModel
from gradio_client import Client, handle_file
from huggingface_hub import snapshot_download
import tempfile
import cv2
import os
import requests

# Download the MagicQuill model weights into ./models.
snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")

HF_TOKEN = os.environ.get("HF_TOKEN")
# Client for the remote "LiuZichen/DrawNGuess" Space, used by guess_prompt_handler.
client = Client("LiuZichen/DrawNGuess", hf_token=HF_TOKEN)

scribbleColorEditModel = ScribbleColorEditModel()

# Data-URL prefixes accepted by read_base64_image.
_SUPPORTED_DATA_URL_PREFIXES = (
    "data:image/png;base64,",
    "data:image/jpeg;base64,",
    "data:image/webp;base64,",
)


def tensor_to_numpy(tensor):
    """Convert a float tensor (expected range [0, 1]) to a uint8 NumPy array.

    Non-tensor inputs are returned unchanged. Scaled values are clipped to
    [0, 255] before the cast so out-of-range floats saturate instead of
    wrapping around.
    """
    if isinstance(tensor, torch.Tensor):
        scaled = tensor.detach().cpu().numpy() * 255
        return np.clip(scaled, 0, 255).astype(np.uint8)
    return tensor


def tensor_to_base64(tensor):
    """Encode a single-image tensor as a base64 PNG string (no data-URL prefix).

    The leading batch dimension is squeezed away. The remaining values
    (expected range [0, 1]) are scaled to [0, 255] and clamped, then
    converted to uint8 before PNG encoding.
    """
    tensor = (tensor.squeeze(0) * 255.).clamp(0, 255)
    pil_image = Image.fromarray(tensor.cpu().byte().numpy())
    buffered = io.BytesIO()
    pil_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str


def read_base64_image(base64_image):
    """Decode a PNG/JPEG/WebP data URL into a PIL image.

    The image is passed through ``ImageOps.exif_transpose`` so EXIF
    orientation is applied.

    Raises:
        ValueError: if ``base64_image`` does not start with a supported
            ``data:image/...;base64,`` prefix.
    """
    if not base64_image.startswith(_SUPPORTED_DATA_URL_PREFIXES):
        raise ValueError("Unsupported image format.")
    base64_image = base64_image.split(",")[1]
    image_data = base64.b64decode(base64_image)
    image = Image.open(io.BytesIO(image_data))
    image = ImageOps.exif_transpose(image)
    return image


def create_alpha_mask(base64_image):
    """Create a mask from the alpha channel of a base64 data-URL image.

    Returns:
        A float32 CPU tensor of shape (1, H, W). Each value is ``1 - alpha/255``,
        so fully transparent pixels are 1 and fully opaque pixels are 0.
        If the image has no alpha band, the mask is all zeros.
    """
    image = read_base64_image(base64_image)
    mask = torch.zeros((1, image.height, image.width), dtype=torch.float32, device="cpu")
    if 'A' in image.getbands():
        alpha_channel = np.array(image.getchannel('A')).astype(np.float32) / 255.0
        mask[0] = 1.0 - torch.from_numpy(alpha_channel)
    return mask


def load_and_preprocess_image(base64_image, convert_to='RGB', has_alpha=False):
    """Load a base64 data-URL image as a float tensor in [0, 1].

    Args:
        base64_image: data URL accepted by ``read_base64_image``.
        convert_to: PIL mode to convert to before tensorising.
        has_alpha: unused; kept for call compatibility.

    Returns:
        A float32 tensor with a leading batch dimension of 1. For
        multi-band modes such as 'RGB' the shape is (1, H, W, C).
    """
    image = read_base64_image(base64_image)
    image = image.convert(convert_to)
    image_array = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_array)[None,]
    return image_tensor


def load_and_resize_image(base64_image, convert_to='RGB', max_size=512):
    """Load a base64 data-URL image and rescale it so its shorter side is ``max_size``.

    Despite the parameter name, the image is ALWAYS rescaled. Smaller images
    are upscaled too, because the ``min(width, height) > max_size`` check
    that would restrict this to downscaling is disabled. Resampling uses
    LANCZOS.

    Returns:
        A float32 tensor in [0, 1] with a leading batch dimension of 1
        ((1, H, W, C) for 'RGB').
    """
    image = read_base64_image(base64_image)
    image = image.convert(convert_to)
    width, height = image.size
    scaling_factor = max_size / min(width, height)
    new_size = (int(width * scaling_factor), int(height * scaling_factor))
    image = image.resize(new_size, Image.LANCZOS)
    image_array = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image_array)[None,]
    return image_tensor


def prepare_images_and_masks(total_mask, original_image, add_color_image, add_edge_image, remove_edge_image):
    """Decode the front-end data URLs into the tensors expected by ``generate``.

    An empty ``add_color_image`` falls back to the original image. An empty
    ``add_edge_image`` / ``remove_edge_image`` yields an all-zero mask shaped
    like ``total_mask``.

    Returns:
        (add_color_image_tensor, original_image_tensor, total_mask,
        add_edge_mask, remove_edge_mask)
    """
    total_mask = create_alpha_mask(total_mask)
    original_image_tensor = load_and_preprocess_image(original_image)
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(add_color_image)
    else:
        add_color_image_tensor = original_image_tensor

    add_edge_mask = create_alpha_mask(add_edge_image) if add_edge_image else torch.zeros_like(total_mask)
    remove_edge_mask = create_alpha_mask(remove_edge_image) if remove_edge_image else torch.zeros_like(total_mask)
    return add_color_image_tensor, original_image_tensor, total_mask, add_edge_mask, remove_edge_mask


def guess_prompt_handler(original_image, add_color_image, add_edge_image):
    """Ask the remote DrawNGuess Space to guess a prompt for the user's edit.

    The inputs are base64 data URLs; ``add_color_image`` and ``add_edge_image``
    may be empty. They are decoded and written to temporary PNG files, then
    passed to ``client.predict`` in the order (original image, colour image,
    edge mask). The temporary files are always removed, even if the remote
    call fails.

    Returns:
        Whatever ``client.predict`` returns.

    Raises:
        OSError: if a temporary PNG could not be written.
    """
    original_image_tensor = load_and_preprocess_image(original_image)
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(add_color_image)
    else:
        add_color_image_tensor = original_image_tensor

    # load_and_preprocess_image returns (1, H, W, C): dim 1 is the height,
    # dim 2 the width. The fallback mask must be (1, H, W) like create_alpha_mask's.
    height, width = original_image_tensor.shape[1], original_image_tensor.shape[2]
    if add_edge_image:
        add_edge_mask = create_alpha_mask(add_edge_image)
    else:
        add_edge_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")

    # Convert tensors to uint8 arrays; cv2.imwrite expects BGR for colour images.
    original_image_numpy = cv2.cvtColor(
        tensor_to_numpy(original_image_tensor.squeeze(0)), cv2.COLOR_RGB2BGR
    )
    add_color_image_numpy = cv2.cvtColor(
        tensor_to_numpy(add_color_image_tensor.squeeze(0)), cv2.COLOR_RGB2BGR
    )
    # (1, H, W) mask -> (H, W, 1) single-channel image.
    add_edge_mask_numpy = tensor_to_numpy(add_edge_mask.squeeze(0).unsqueeze(-1))

    temp_paths = []
    try:
        for array in (original_image_numpy, add_color_image_numpy, add_edge_mask_numpy):
            # Create and close the file first, then let OpenCV write to its path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                temp_paths.append(tmp.name)
            if not cv2.imwrite(tmp.name, array):
                raise OSError(f"Failed to write temporary image {tmp.name}")
        # Call the remote API with the temporary file paths.
        return client.predict(*(handle_file(path) for path in temp_paths))
    finally:
        for path in temp_paths:
            if os.path.exists(path):
                os.remove(path)


def generate(ckpt_name, total_mask, original_image, add_color_image, add_edge_image, remove_edge_image, positive_prompt, negative_prompt, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
    """Run one MagicQuill edit and return the result as a base64 PNG.

    Image and mask arguments are base64 data URLs from the front end. The
    other arguments are forwarded to ``scribbleColorEditModel.process``,
    except for the prompt/edge-strength adjustment described below.

    Returns:
        The final image encoded by ``tensor_to_base64`` (no data-URL prefix).
    """
    add_color_image, original_image, total_mask, add_edge_mask, remove_edge_mask = prepare_images_and_masks(
        total_mask, original_image, add_color_image, add_edge_image, remove_edge_image
    )
    # When edges were only erased (none added), divide the edge strength by 3.
    # Also substitute a neutral prompt if none was given.
    if torch.sum(remove_edge_mask).item() > 0 and torch.sum(add_edge_mask).item() == 0:
        if positive_prompt == "":
            positive_prompt = "empty scene"
        edge_strength /= 3.

    progress = None
    _latent_samples, final_image, _lineart_output, _color_output = scribbleColorEditModel.process(
        ckpt_name,
        original_image,
        add_color_image,
        positive_prompt,
        negative_prompt,
        total_mask,
        add_edge_mask,
        remove_edge_mask,
        grow_size,
        stroke_as_edge,
        fine_edge,
        edge_strength,
        color_strength,
        inpaint_strength,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        progress,
    )
    return tensor_to_base64(final_image)


def generate_image_handler(x, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
    """Gradio click handler: run ``generate`` on the MagicQuill component value.

    ``x`` is the MagicQuill component value. Images/masks are read from
    ``x['from_frontend']`` and the prompt from ``x['from_backend']['prompt']``.
    The generated base64 image is stored in
    ``x['from_backend']['generated_image']``, and ``x`` is returned.

    A seed of -1 (or an empty Seed box, ``None``) picks a random 32-bit seed.
    Stroke-as-edge is always enabled.
    """
    if seed is None or seed == -1:
        seed = random.randint(0, 2**32 - 1)
    ms_data = x['from_frontend']
    positive_prompt = x['from_backend']['prompt']
    stroke_as_edge = "enable"
    res = generate(
        ckpt_name,
        ms_data['total_mask'],
        ms_data['original_image'],
        ms_data['add_color_image'],
        ms_data['add_edge_image'],
        ms_data['remove_edge_image'],
        positive_prompt,
        negative_prompt,
        grow_size,
        stroke_as_edge,
        fine_edge,
        edge_strength,
        color_strength,
        inpaint_strength,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
    )
    x["from_backend"]["generated_image"] = res
    return x


css = '''
.row {
    width: 90%;
    margin: auto;
}
footer {
    visibility: hidden
}
'''

with gr.Blocks(css=css) as demo:
    with gr.Row(elem_classes="row"):
        ms = MagicQuill()

    with gr.Row(elem_classes="row"):
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column():
            with gr.Accordion("parameters", open=False):
                ckpt_name = gr.Dropdown(
                    label="Base Model Name",
                    choices=folder_paths.get_filename_list("checkpoints"),
                    value='SD1.5/realisticVisionV60B1_v51VAE.safetensors',
                    interactive=True
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    interactive=True
                )
                # Stroke-as-edge is fixed to "enable" in generate_image_handler,
                # so it is not exposed as a control.
                fine_edge = gr.Radio(
                    label="Fine Edge",
                    choices=['enable', 'disable'],
                    value='disable',
                    interactive=True
                )
                grow_size = gr.Slider(
                    label="Grow Size",
                    minimum=0,
                    maximum=100,
                    value=15,
                    step=1,
                    interactive=True
                )
                edge_strength = gr.Slider(
                    label="Edge Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.6,
                    step=0.01,
                    interactive=True
                )
                color_strength = gr.Slider(
                    label="Color Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.6,
                    step=0.01,
                    interactive=True
                )
                inpaint_strength = gr.Slider(
                    label="Inpaint Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=1.0,
                    step=0.01,
                    interactive=True
                )
                seed = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    interactive=True
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=50,
                    value=20,
                    step=1,
                    interactive=True
                )
                cfg = gr.Slider(
                    label="CFG",
                    minimum=0.0,
                    maximum=20.0,
                    value=5.0,
                    step=0.1,
                    interactive=True
                )
                sampler_name = gr.Dropdown(
                    label="Sampler Name",
                    choices=["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ddim", "uni_pc", "uni_pc_bh2"],
                    value='euler_ancestral',
                    interactive=True
                )
                scheduler = gr.Dropdown(
                    label="Scheduler",
                    choices=["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"],
                    value='karras',
                    interactive=True
                )

    btn.click(
        generate_image_handler,
        inputs=[ms, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler],
        outputs=ms,
        concurrency_limit=1,
    )

demo.queue(max_size=20, status_update_rate=0.1)

app = FastAPI()


@app.post("/magic_quill/guess_prompt")
async def guess_prompt(request: Request):
    """Return the DrawNGuess prompt guess for the posted images.

    The JSON body must contain 'original_image', 'add_color_image' and
    'add_edge_image'.
    """
    data = await request.json()
    # guess_prompt_handler does blocking file I/O and a remote call. Run it in
    # the threadpool so it does not stall the event loop shared with Gradio.
    res = await run_in_threadpool(
        guess_prompt_handler,
        data['original_image'],
        data['add_color_image'],
        data['add_edge_image'],
    )
    return res


def _resize_to_png_data_url(base64_image):
    """Resize a data-URL image with load_and_resize_image; return it as a PNG data URL."""
    resized_img_tensor = load_and_resize_image(base64_image)
    return "data:image/png;base64," + tensor_to_base64(resized_img_tensor)


@app.post("/magic_quill/process_background_img")
async def process_background_img(request: Request):
    """Resize the posted background image (a JSON-encoded data-URL string).

    Returns it as a PNG data URL.
    """
    img = await request.json()
    # Decoding, resizing and PNG encoding are CPU-bound; keep them off the event loop.
    resized_img_base64 = await run_in_threadpool(_resize_to_png_data_url, img)
    # add more processing here
    return resized_img_base64


app = gr.mount_gradio_app(app, demo, "/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)