import gradio as gr  # pyright: ignore[reportMissingTypeStubs]
import pillow_heif  # pyright: ignore[reportMissingTypeStubs]
import spaces  # pyright: ignore[reportMissingTypeStubs]
import torch
from PIL import Image
from refiners.fluxion.utils import manual_seed, no_grad

from utils import LightingPreference, load_ic_light, resize_modulo_8

# let PIL open HEIF/AVIF uploads
pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]

TITLE = """
# IC-Light with Refiners
"""

# initialize the IC-Light model, on the cpu
DEVICE_CPU = torch.device("cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
ic_light = load_ic_light(device=DEVICE_CPU, dtype=DTYPE)

# "move" the IC-Light model to the gpu, this is handled/intercepted by Zero GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ic_light.to(device=DEVICE, dtype=DTYPE)
ic_light.device = DEVICE
ic_light.dtype = DTYPE
ic_light.solver = ic_light.solver.to(device=DEVICE, dtype=DTYPE)


def _run_pass(
    x: torch.Tensor,
    *,
    strength: float,
    num_inference_steps: int,
    clip_text_embedding: torch.Tensor,
    condition_scale: float,
    renoise: bool,
) -> torch.Tensor:
    """Run one denoising pass of the global `ic_light` model on the latents `x`.

    The solver is configured with `round(num_inference_steps / strength)` total
    steps, starting at step `int(total * (1 - strength))`.

    Args:
        x: Latents to denoise.
        strength: Pass strength in [0, 1]. A strength of 0 skips the pass
            entirely and returns `x` untouched (this also avoids dividing by zero).
        num_inference_steps: Number of inference steps requested by the user.
        clip_text_embedding: Text embedding forwarded to the model at every step.
        condition_scale: Condition scale forwarded to the model at every step.
        renoise: When true, noise is added to `x` at the first step of the pass,
            after the solver has been configured.

    Returns:
        The latents after the pass.
    """
    if strength <= 0:
        # a zero-strength pass is a no-op; dividing by it below would raise
        return x

    num_steps = int(round(num_inference_steps / strength))
    first_step = int(num_steps * (1 - strength))
    ic_light.set_inference_steps(num_steps, first_step)

    if renoise:
        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=first_step)

    for step in ic_light.steps:
        x = ic_light(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )
    return x


@spaces.GPU
@no_grad()
def process(
    image: Image.Image,
    light_pref: str,
    prompt: str,
    negative_prompt: str,
    strength_first_pass: float,
    strength_second_pass: float,
    condition_scale: float,
    num_inference_steps: int,
    seed: int,
) -> Image.Image:
    """Relight an RGBA image with IC-Light using two denoising passes.

    Args:
        image: Input image in RGBA mode; the alpha channel is used as the mask.
        light_pref: Lighting preference, parsed with `LightingPreference.from_str`.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        strength_first_pass: Strength of the first pass, in [0, 1]. Forced to 1
            when the lighting preference yields no initial image.
        strength_second_pass: Strength of the second pass, in [0, 1].
        condition_scale: Condition scale forwarded to the model.
        num_inference_steps: Number of inference steps, must be positive.
        seed: Non-negative random seed.

    Returns:
        The image decoded from the final latents.

    Raises:
        gr.Error: If any of the inputs is outside its valid domain.
    """
    # validate with explicit errors: `assert` is stripped under `python -O` and
    # gives no readable feedback in the UI
    if image.mode != "RGBA":
        raise gr.Error("The input image must be in RGBA mode.")
    if not 0 <= strength_first_pass <= 1:
        raise gr.Error("The strength of the first pass must be between 0 and 1.")
    if not 0 <= strength_second_pass <= 1:
        raise gr.Error("The strength of the second pass must be between 0 and 1.")
    if num_inference_steps <= 0:
        raise gr.Error("The number of inference steps must be positive.")
    if seed < 0:
        raise gr.Error("The seed must be non-negative.")

    # set the seed
    manual_seed(seed)

    # resize image to ~768x768
    image = resize_modulo_8(image, 768)

    # split RGB and alpha channel
    mask = image.getchannel("A")
    image = image.convert("RGB")

    # compute embeddings
    clip_text_embedding = ic_light.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    ic_light.set_ic_light_condition(image=image, mask=mask)

    # get the light_pref_image
    light_pref_image = LightingPreference.from_str(value=light_pref).get_init_image(
        width=image.width,
        height=image.height,
        interval=(0.2, 0.8),
    )

    # if no light preference is provided, do a full strength first pass
    if light_pref_image is None:
        x = torch.randn_like(ic_light._ic_light_condition)  # pyright: ignore[reportPrivateUsage]
        strength_first_pass = 1.0
    else:
        x = ic_light.lda.image_to_latents(light_pref_image)
        # noise is added here, before the first pass reconfigures the solver
        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=0)

    # first pass: latents were initialized above, so no extra noise is added
    x = _run_pass(
        x,
        strength=strength_first_pass,
        num_inference_steps=num_inference_steps,
        clip_text_embedding=clip_text_embedding,
        condition_scale=condition_scale,
        renoise=False,
    )

    # second pass: re-noise the first pass result at the pass's first step
    x = _run_pass(
        x,
        strength=strength_second_pass,
        num_inference_steps=num_inference_steps,
        clip_text_embedding=clip_text_embedding,
        condition_scale=condition_scale,
        renoise=True,
    )

    return ic_light.lda.latents_to_image(x)


with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image (RGBA)",
                image_mode="RGBA",
                type="pil",
            )
            run_button = gr.Button(
                value="Relight Image",
            )
        with gr.Column():
            output_image = gr.Image(
                label="Relighted Image (RGB)",
                image_mode="RGB",
                type="pil",
            )

    with gr.Accordion("Advanced Settings", open=True):
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="bright green neon light, best quality, highres",
        )
        neg_prompt = gr.Textbox(
            label="Negative Prompt",
            placeholder="worst quality, low quality, normal quality",
        )
        light_pref = gr.Radio(
            choices=["None", "Left", "Right", "Top", "Bottom"],
            label="Light direction preference",
            value="None",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=100_000,
            value=69_420,
            step=1,
        )
        condition_scale = gr.Slider(
            label="Condition scale",
            minimum=0.5,
            maximum=2,
            value=1.25,
            step=0.05,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            value=25,
            step=1,
        )
        with gr.Row():
            strength_first_pass = gr.Slider(
                label="Strength of the first pass",
                minimum=0,
                maximum=1,
                value=0.9,
                step=0.1,
            )
            strength_second_pass = gr.Slider(
                label="Strength of the second pass",
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
            )

    run_button.click(
        fn=process,
        inputs=[
            input_image,
            light_pref,
            prompt,
            neg_prompt,
            strength_first_pass,
            strength_second_pass,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
    )

    gr.Examples(  # pyright: ignore[reportUnknownMemberType]
        examples=[
            [
                "examples/plant.png",
                "None",
                "blue purple neon light, cyberpunk city background, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Right",
                "blue purple neon light, cyberpunk city background, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Left",
                "floor is blue ice cavern, stalactite, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/chair.png",
                "Right",
                "god rays, fluffy clouds, peaceful surreal atmosphere, high-quality, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69,
            ],
            [
                "examples/bunny.png",
                "Left",
                "grass field, high-quality, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                420,
            ],
        ],
        inputs=[
            input_image,
            light_pref,
            prompt,
            neg_prompt,
            strength_first_pass,
            strength_second_pass,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
        fn=process,
        cache_examples=True,
        cache_mode="lazy",
        run_on_click=False,
    )

demo.launch()