ic_light / src /app.py
1aurent's picture
add examples
cb6e7cb unverified
raw
history blame
8.73 kB
import gradio as gr # pyright: ignore[reportMissingTypeStubs]
import pillow_heif # pyright: ignore[reportMissingTypeStubs]
import spaces # pyright: ignore[reportMissingTypeStubs]
import torch
from PIL import Image
from refiners.fluxion.utils import manual_seed, no_grad
from utils import LightingPreference, load_ic_light, resize_modulo_8
# Register HEIF/AVIF decoders so PIL.Image.open can read .heic/.avif uploads.
pillow_heif.register_heif_opener() # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener() # pyright: ignore[reportUnknownMemberType]

# Markdown heading rendered at the top of the Gradio app.
TITLE = """
# IC-Light with Refiners
"""

# initialize the enhancer, on the cpu
DEVICE_CPU = torch.device("cpu")
# Prefer bfloat16 when the hardware supports it, otherwise fall back to float32.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
ic_light = load_ic_light(device=DEVICE_CPU, dtype=DTYPE)

# "move" the enhancer to the gpu, this is handled/intercepted by Zero GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ic_light.to(device=DEVICE, dtype=DTYPE)
# Keep the wrapper's bookkeeping attributes in sync with the actual placement;
# the solver is moved explicitly — presumably .to() above does not cover it. TODO confirm.
ic_light.device = DEVICE
ic_light.dtype = DTYPE
ic_light.solver = ic_light.solver.to(device=DEVICE, dtype=DTYPE)
@spaces.GPU
@no_grad()
def process(
    image: Image.Image,
    light_pref: str,
    prompt: str,
    negative_prompt: str,
    strength_first_pass: float,
    strength_second_pass: float,
    condition_scale: float,
    num_inference_steps: int,
    seed: int,
) -> Image.Image:
    """Relight an RGBA image with IC-Light using a two-pass img2img scheme.

    Args:
        image: Input image; its alpha channel is used as the foreground mask.
        light_pref: Lighting direction preference ("None", "Left", "Right", "Top", "Bottom").
        prompt: Positive text prompt describing the desired lighting/scene.
        negative_prompt: Text prompt describing what to avoid.
        strength_first_pass: Denoising strength of the first pass, in (0, 1].
        strength_second_pass: Denoising strength of the second pass, in (0, 1].
        condition_scale: Scale applied to the IC-Light condition during denoising.
        num_inference_steps: Number of effective denoising steps per pass.
        seed: RNG seed for reproducibility.

    Returns:
        The relighted image, in RGB mode.
    """
    assert image.mode == "RGBA"
    # FIX: strengths must be strictly positive — both are used as divisors
    # below, so the original `0 <= strength` bound let a 0 from the UI slider
    # surface as an opaque ZeroDivisionError.
    assert 0 < strength_first_pass <= 1
    assert 0 < strength_second_pass <= 1
    assert num_inference_steps > 0
    assert seed >= 0

    # set the seed
    manual_seed(seed)

    # resize image to ~768x768 (dimensions kept multiples of 8)
    image = resize_modulo_8(image, 768)

    # split RGB and alpha channel
    mask = image.getchannel("A")
    image = image.convert("RGB")

    # compute embeddings
    clip_text_embedding = ic_light.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    ic_light.set_ic_light_condition(image=image, mask=mask)

    # get the light_pref_image
    light_pref_image = LightingPreference.from_str(value=light_pref).get_init_image(
        width=image.width,
        height=image.height,
        interval=(0.2, 0.8),
    )

    # if no light preference is provided, do a full strength first pass from pure noise
    if light_pref_image is None:
        x = torch.randn_like(ic_light._ic_light_condition)  # pyright: ignore[reportPrivateUsage]
        strength_first_pass = 1.0
    else:
        # encode the preference image and noise it to the solver's first timestep
        x = ic_light.lda.image_to_latents(light_pref_image)
        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=0)

    # configure the first pass: img2img skips the first steps in proportion
    # to (1 - strength), so scale the step count up to keep the effective
    # number of executed steps close to num_inference_steps
    num_steps = int(round(num_inference_steps / strength_first_pass))
    first_step = int(num_steps * (1 - strength_first_pass))
    ic_light.set_inference_steps(num_steps, first_step)

    # first pass
    for step in ic_light.steps:
        x = ic_light(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )

    # configure the second pass (same scheme as the first)
    num_steps = int(round(num_inference_steps / strength_second_pass))
    first_step = int(num_steps * (1 - strength_second_pass))
    ic_light.set_inference_steps(num_steps, first_step)

    # re-noise the first pass's output to the second pass's starting timestep
    x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=first_step)

    # second pass
    for step in ic_light.steps:
        x = ic_light(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )

    return ic_light.lda.latents_to_image(x)
# Build the Gradio UI: image in/out, advanced settings, wiring, and examples.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            # Input must be RGBA: process() uses the alpha channel as the mask.
            input_image = gr.Image(
                label="Input Image (RGBA)",
                image_mode="RGBA",
                type="pil",
            )
            run_button = gr.Button(
                value="Relight Image",
            )
        with gr.Column():
            output_image = gr.Image(
                label="Relighted Image (RGB)",
                image_mode="RGB",
                type="pil",
            )
    with gr.Accordion("Advanced Settings", open=True):
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="bright green neon light, best quality, highres",
        )
        neg_prompt = gr.Textbox(
            label="Negative Prompt",
            placeholder="worst quality, low quality, normal quality",
        )
        # Choices must match the values accepted by LightingPreference.from_str.
        light_pref = gr.Radio(
            choices=["None", "Left", "Right", "Top", "Bottom"],
            label="Light direction preference",
            value="None",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=100_000,
            value=69_420,
            step=1,
        )
        condition_scale = gr.Slider(
            label="Condition scale",
            minimum=0.5,
            maximum=2,
            value=1.25,
            step=0.05,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            value=25,
            step=1,
        )
        # NOTE(review): a strength of 0 passes these sliders but is used as a
        # divisor in process() — confirm whether minimum should be 0.1.
        with gr.Row():
            strength_first_pass = gr.Slider(
                label="Strength of the first pass",
                minimum=0,
                maximum=1,
                value=0.9,
                step=0.1,
            )
            strength_second_pass = gr.Slider(
                label="Strength of the second pass",
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
            )

    # Wire the button to process(); input order must match process()'s signature.
    run_button.click(
        fn=process,
        inputs=[
            input_image,
            light_pref,
            prompt,
            neg_prompt,
            strength_first_pass,
            strength_second_pass,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
    )

    # Clickable examples; each row mirrors process()'s parameter order.
    # cache_mode="lazy" computes each example's output on first click instead
    # of at startup; run_on_click=False shows cached results without re-running.
    # NOTE(review): "studo" in the prompts below looks like a typo for "studio";
    # left as-is since these strings are fed verbatim to the model.
    gr.Examples(  # pyright: ignore[reportUnknownMemberType]
        examples=[
            [
                "examples/plant.png",
                "None",
                "blue purple neon light, cyberpunk city background, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Right",
                "blue purple neon light, cyberpunk city background, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Left",
                "floor is blue ice cavern, stalactite, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/chair.png",
                "Right",
                "god rays, fluffy clouds, peaceful surreal atmosphere, high-quality, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69,
            ],
            [
                "examples/bunny.png",
                "Left",
                "grass field, high-quality, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                420,
            ],
        ],
        inputs=[
            input_image,
            light_pref,
            prompt,
            neg_prompt,
            strength_first_pass,
            strength_second_pass,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
        fn=process,
        cache_examples=True,
        cache_mode="lazy",
        run_on_click=False,
    )

# Launch the app (blocking call).
demo.launch()