import gradio as gr
import pillow_heif
import spaces
import torch
from PIL import Image
from refiners.fluxion.utils import manual_seed, no_grad

from utils import LightingPreference, load_ic_light, resize_modulo_8

pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

TITLE = """
# IC-Light with Refiners
"""

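# Load IC-Light on the CPU first; the weights are moved to the GPU below.
# bfloat16 is used when the GPU supports it, float32 otherwise.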
DEVICE_CPU = torch.device("cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
ic_light = load_ic_light(device=DEVICE_CPU, dtype=DTYPE)

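# Move the model and its solver to the GPU when one is available, keeping the
# model's device/dtype attributes in sync.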
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ic_light.to(device=DEVICE, dtype=DTYPE)
ic_light.device = DEVICE
ic_light.dtype = DTYPE
ic_light.solver = ic_light.solver.to(device=DEVICE, dtype=DTYPE)


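# Relight an RGBA image in two passes: a first pass seeded by the optional
# lighting-direction preference, then a second pass that refines the result.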
@spaces.GPU
@no_grad()
def process(
    image: Image.Image,
    light_pref: str,
    prompt: str,
    negative_prompt: str,
    strength_first_pass: float,
    strength_second_pass: float,
    condition_scale: float,
    num_inference_steps: int,
    seed: int,
) -> Image.Image:
    assert image.mode == "RGBA"
    assert 0 < strength_first_pass <= 1  # strictly positive: the strengths divide the step count below
    assert 0 < strength_second_pass <= 1
    assert num_inference_steps > 0
    assert seed >= 0

    manual_seed(seed)

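    # Resize so both sides are multiples of 8 (needed for the latent autoencoder);
    # 768 is the size hint passed by this demo.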
    image = resize_modulo_8(image, 768)

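    # The alpha channel is the foreground mask; the RGB channels are the image to relight.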
    mask = image.getchannel("A")
    image = image.convert("RGB")

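    # Encode the prompt and negative prompt, then register the masked image as
    # the IC-Light condition.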
    clip_text_embedding = ic_light.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    ic_light.set_ic_light_condition(image=image, mask=mask)

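    # Optional init image encoding the preferred light direction ("None" yields no image).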
    light_pref_image = LightingPreference.from_str(value=light_pref).get_init_image(
        width=image.width,
        height=image.height,
        interval=(0.2, 0.8),
    )

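    # Without a lighting preference, start from pure noise (full-strength first
    # pass); otherwise start from the fully noised latents of the init image.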
    if light_pref_image is None:
        x = torch.randn_like(ic_light._ic_light_condition)
        strength_first_pass = 1.0
    else:
        x = ic_light.lda.image_to_latents(light_pref_image)
        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=0)

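    # First pass: stretch the schedule by the strength and skip its first
    # (1 - strength) fraction, as in img2img denoising.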
    num_steps = int(round(num_inference_steps / strength_first_pass))
    first_step = int(num_steps * (1 - strength_first_pass))
    ic_light.set_inference_steps(num_steps, first_step)

    for step in ic_light.steps:
        x = ic_light(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )

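    # Second pass: re-noise the intermediate latents at first_step and denoise
    # again, this time controlled by strength_second_pass.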
    num_steps = int(round(num_inference_steps / strength_second_pass))
    first_step = int(num_steps * (1 - strength_second_pass))
    ic_light.set_inference_steps(num_steps, first_step)

    x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=first_step)

    for step in ic_light.steps:
        x = ic_light(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )

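    # Decode the final latents back to an RGB image.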
    return ic_light.lda.latents_to_image(x)


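# Gradio UI: image input/output, advanced settings, and cached examples.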
with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image (RGBA)",
                image_mode="RGBA",
                type="pil",
            )
            run_button = gr.Button(
                value="Relight Image",
            )
        with gr.Column():
            output_image = gr.Image(
                label="Relit Image (RGB)",
                image_mode="RGB",
                type="pil",
            )

    with gr.Accordion("Advanced Settings", open=True):
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="bright green neon light, best quality, highres",
        )
        neg_prompt = gr.Textbox(
            label="Negative Prompt",
            placeholder="worst quality, low quality, normal quality",
        )
        light_pref = gr.Radio(
            choices=["None", "Left", "Right", "Top", "Bottom"],
            label="Light direction preference",
            value="None",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=100_000,
            value=69_420,
            step=1,
        )
        condition_scale = gr.Slider(
            label="Condition scale",
            minimum=0.5,
            maximum=2,
            value=1.25,
            step=0.05,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            value=25,
            step=1,
        )
        with gr.Row():
            strength_first_pass = gr.Slider(
                label="Strength of the first pass",
                minimum=0,
                maximum=1,
                value=0.9,
                step=0.1,
            )
            strength_second_pass = gr.Slider(
                label="Strength of the second pass",
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.1,
            )

    run_button.click(
        fn=process,
        inputs=[
            input_image,
            light_pref,
            prompt,
            neg_prompt,
            strength_first_pass,
            strength_second_pass,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
    )

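    # Example inputs; with lazy caching, each output is computed and cached the
    # first time that example is selected.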
    gr.Examples(
        examples=[
            [
                "examples/plant.png",
                "None",
                "blue purple neon light, cyberpunk city background, high-quality professional studio photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Right",
                "blue purple neon light, cyberpunk city background, high-quality professional studio photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Left",
                "floor is blue ice cavern, stalactite, high-quality professional studio photography, realistic soft lighting, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/chair.png",
                "Right",
                "god rays, fluffy clouds, peaceful surreal atmosphere, high-quality, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                69,
            ],
            [
                "examples/bunny.png",
                "Left",
                "grass field, high-quality, HEIC, CR2, NEF",
                "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white",
                0.9,
                0.5,
                1.25,
                25,
                420,
            ],
        ],
        inputs=[
            input_image,
            light_pref,
            prompt,
            neg_prompt,
            strength_first_pass,
            strength_second_pass,
            condition_scale,
            num_inference_steps,
            seed,
        ],
        outputs=output_image,
        fn=process,
        cache_examples=True,
        cache_mode="lazy",
        run_on_click=False,
    )

demo.launch()