"""Gradio demo: pick a task, click a latent-space heatmap, and generate an image."""

from typing import Optional

import gradio as gr
from PIL import Image, ImageDraw

from inference import generate_image

TASK_TO_INDEX = {"Task 1": 0, "Task 2": 1, "Task 3": 2, "Task 4": 3}

# Side length (in pixels) used to normalise click coordinates into [0, 1).
# NOTE(review): assumed to equal the width/height of imgs/heatmap_*.png — confirm.
HEATMAP_SIZE = 1155


def _pattern_path(idx: int) -> str:
    """Return the path of the reference pattern image for task ``idx``."""
    return f"imgs/pattern_{idx}.png"


def _heatmap_path(idx: int) -> str:
    """Return the path of the coordinate-selector heatmap for task ``idx``."""
    return f"imgs/heatmap_{idx}.png"


def create_marker_overlay(image_path: str, x: int, y: int) -> Image.Image:
    """
    Return a copy of the image at ``image_path`` with a red cross drawn at (x, y).

    Args:
        image_path: Path of the image to annotate; the file itself is not modified.
        x: Horizontal pixel coordinate of the cross centre.
        y: Vertical pixel coordinate of the cross centre.

    Returns:
        A new in-memory image; the source file handle is closed before returning.
    """
    with Image.open(image_path) as base_image:
        # copy()/convert() fully load the pixels, so the result stays valid
        # after the file is closed. Non-RGB(A) images (e.g. "L" or "P") are
        # converted so that the "red" ink is rendered as an actual red.
        if base_image.mode in ("RGB", "RGBA"):
            marked_image = base_image.copy()
        else:
            marked_image = base_image.convert("RGB")

    draw = ImageDraw.Draw(marked_image)
    marker_size = 10  # half-length of each arm of the cross, in pixels
    marker_color = "red"
    draw.line([x - marker_size, y, x + marker_size, y], fill=marker_color, width=2)
    draw.line([x, y - marker_size, x, y + marker_size], fill=marker_color, width=2)
    return marked_image


def update_reference_image(choice: int) -> tuple[str, int, str]:
    """
    Update the reference image display based on radio button selection.

    Returns the pattern image path, the selected index, and the matching heatmap path.
    """
    return _pattern_path(choice), choice, _heatmap_path(choice)


def update_marker(image_idx: int, evt: gr.SelectData) -> tuple[Image.Image, tuple[int, int]]:
    """
    Draw a marker on the task's heatmap at the clicked pixel.

    Returns the marked image and the (x, y) click coordinates, which are
    stored in state for the follow-up generation step.
    """
    x, y = evt.index[0], evt.index[1]
    # Always start from the clean heatmap so earlier markers do not accumulate.
    return create_marker_overlay(_heatmap_path(image_idx), x, y), (x, y)


def generate_output_image(
    image_idx: int, coords: Optional[tuple[int, int]]
) -> Optional[Image.Image]:
    """
    Generate the output image from the selected heatmap coordinates.

    Pixel coordinates are divided by ``HEATMAP_SIZE`` before being passed to
    ``generate_image``. Returns ``None`` (empty output) if no point has been
    selected yet.
    """
    if coords is None:
        return None
    x, y = coords
    x_norm, y_norm = x / HEATMAP_SIZE, y / HEATMAP_SIZE
    return generate_image(image_idx, x_norm, y_norm)


def _on_task_change(task_name: str) -> tuple[str, int, str]:
    """Map a radio-button label to its task index and refresh that task's images."""
    return update_reference_image(TASK_TO_INDEX[task_name])


with gr.Blocks(
    css="""
    .radio-container {
        width: 450px !important;
        margin-left: auto !important;
        margin-right: auto !important;
    }
    .coordinate-container {
        width: 600px !important;
        height: 600px !important;
    }
    .coordinate-container img {
        width: 100% !important;
        height: 100% !important;
        object-fit: contain !important;
    }
    .documentation {
        margin-top: 2rem !important;
        padding: 1rem !important;
        background-color: #f8f9fa !important;
        border-radius: 8px !important;
    }
    """
) as demo:
    gr.Markdown(
        """
        # Interactive Image Generation
        Select a task using the radio buttons, then click on the coordinate selector to generate a new image.
        """
    )

    with gr.Row():
        # Left column
        with gr.Column(scale=1):
            selected_idx = gr.State(value=0)
            coords = gr.State()  # (x, y) of the last click; None until the first click
            with gr.Column(elem_classes="radio-container"):
                task_select = gr.Radio(
                    choices=list(TASK_TO_INDEX),
                    value="Task 1",
                    label="Select Task",
                    interactive=True,
                )
            gr.Markdown("### Reference Pattern")
            reference_image = gr.Image(
                value=_pattern_path(0),
                show_label=False,
                interactive=False,
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
            )
            gr.Markdown("### Generated Output")
            output_image = gr.Image(
                show_label=False,
                height=300,
                width=450,
                show_download_button=False,
                show_fullscreen_button=False,
                interactive=False,
            )

        # Right column
        with gr.Column(scale=1):
            gr.Markdown("### Coordinate Selector")
            gr.Markdown("Click anywhere in the image below to select (x, y) coordinates in the latent space")
            with gr.Column(elem_classes="coordinate-container"):
                coord_selector = gr.Image(
                    value=_heatmap_path(0),
                    show_label=False,
                    interactive=False,
                    sources=[],
                    container=True,
                    show_download_button=False,
                    show_fullscreen_button=False,
                )

    # Documentation section
    with gr.Column(elem_classes="documentation"):
        gr.Markdown(
            """
            ## Method Documentation

            ### How It Works
            This interactive demo showcases our novel image generation method that uses coordinate-based control. The process works as follows:

            1. **Task Selection**: Choose one of four different pattern generation tasks
            2. **Reference Pattern**: View the target pattern for the selected task
            3. **Coordinate Selection**: Click anywhere in the heatmap to specify where in the latent space you want to generate from
            4. **Generation**: The model generates a new image based on your selected coordinates

            ### Sample Results
            Here are some example outputs from our method:

            ![LPN Diagram](imgs/lpn_diagram.png)

            ### Technical Details
            Our approach uses a novel coordinate-conditioning mechanism that allows precise control over the generated patterns. The heatmap visualization shows the distribution of pattern characteristics across the latent space.

            For more information, please refer to our [paper](https://arxiv.org/pdf/2411.08706) or GitHub [repository](https://github.com/clement-bonnet/lpn).
            """
        )

    # Event handlers
    task_select.change(
        fn=_on_task_change,
        inputs=[task_select],
        outputs=[reference_image, selected_idx, coord_selector],
    )

    # Two chained events: first draw the marker and store the click in state,
    # then generate from the stored coordinates so the marker appears
    # immediately while generation is still running.
    coord_selector.select(
        fn=update_marker,
        inputs=[selected_idx],
        outputs=[coord_selector, coords],
        trigger_mode="multiple",
    ).then(
        fn=generate_output_image,
        inputs=[selected_idx, coords],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()