# rorschach / app.py
# Michael Ramos · "layout" · d0af695
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
import torch
import os
try:
    # Optional Intel extension; importing it is best-effort. Presumably it
    # enables Intel XPU support for the `torch.xpu` check below — confirm.
    import intel_extension_for_pytorch as ipex  # noqa: F401
except Exception:
    # Not installed or failed to initialise: continue without it. Unlike a bare
    # `except:`, this does not swallow KeyboardInterrupt / SystemExit.
    pass
from PIL import Image
import numpy as np
import gradio as gr
import psutil
import time
import math
# Runtime switches read from the environment (strings, or None when unset).
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Check whether MPS is available (macOS on Apple M1/M2/M3 chips only).
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

# Preferred accelerator: CUDA, then Intel XPU, otherwise CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
# `torch_device`/`torch_dtype` describe the initial placement of the models;
# `device` is where they finally run.
torch_device = device
# Half precision on accelerators. float16 is poorly supported on CPU (slow,
# and some ops may lack Half kernels depending on the PyTorch version), so a
# CPU-only machine uses float32.
torch_dtype = torch.float32 if device.type == "cpu" else torch.float16

if mps_available:
    # Apple Silicon: initial placement on CPU in float32, final device is MPS.
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

# Log the configuration after the MPS override so the reported device is the
# one actually used.
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"TORCH_COMPILE: {TORCH_COMPILE}")
print(f"device: {device}")
def _load_sdxl_turbo(pipeline_cls):
    """Load ``stabilityai/sdxl-turbo`` with *pipeline_cls* and prepare it for use.

    The weights are loaded in ``torch_dtype``. The pipeline is placed on
    ``torch_device`` and then moved to ``device``, and its progress bar is
    disabled. The safety checker is passed as ``None`` unless the
    ``SAFETY_CHECKER`` environment switch is exactly ``"True"``.

    Args:
        pipeline_cls: a diffusers auto-pipeline class exposing
            ``from_pretrained`` (e.g. ``AutoPipelineForImage2Image``).

    Returns:
        The loaded, placed pipeline.
    """
    load_kwargs = {
        "torch_dtype": torch_dtype,
        # A variant selects alternately named weight files. "fp16" is used for
        # half precision. For full precision, None loads the repository's
        # default weights; the repo presumably publishes no "fp32" variant.
        "variant": "fp16" if torch_dtype == torch.float16 else None,
    }
    if SAFETY_CHECKER != "True":
        # NOTE(review): confirm the SDXL pipelines actually have a
        # safety_checker component; if not, this kwarg has no effect.
        load_kwargs["safety_checker"] = None
    pipe = pipeline_cls.from_pretrained("stabilityai/sdxl-turbo", **load_kwargs)
    pipe.to(device=torch_device, dtype=torch_dtype).to(device)
    pipe.set_progress_bar_config(disable=True)
    return pipe


i2i_pipe = _load_sdxl_turbo(AutoPipelineForImage2Image)
# NOTE(review): t2i_pipe is not referenced in the code visible here, and it
# loads a second, separate copy of the weights. Consider dropping it or
# deriving it from i2i_pipe so components are shared.
t2i_pipe = _load_sdxl_turbo(AutoPipelineForText2Image)
def resize_crop(image, size=512):
    """Return an RGB copy of *image* scaled to *size* pixels wide.

    The height is scaled by the same factor and truncated to an int, so the
    aspect ratio is preserved. Despite the name, no cropping happens.

    Args:
        image: a PIL image in any mode.
        size: target width in pixels.

    Returns:
        The resized RGB PIL image (bicubic resampling).
    """
    rgb = image.convert("RGB")
    width, height = rgb.size
    new_height = int(size * (height / width))
    return rgb.resize((size, new_height), Image.BICUBIC)
# Index (into `images`) of the gallery image most recently clicked, or None.
# NOTE(review): this is a module-level global, so every connected session
# shares it — confirm that is acceptable for this Space.
selected_image_index = None


def _load_image(path):
    """Open the image at *path* and decode it immediately.

    Eager decoding surfaces corrupt files at startup. It also lets Pillow close
    the file handle for single-frame images instead of keeping every file open
    until first use.
    """
    img = Image.open(path)
    img.load()
    return img


# Load images from the 'images' folder, sorted by file name so each image's
# gallery index is stable.
image_folder = 'images'
images = [
    _load_image(os.path.join(image_folder, name))
    for name in sorted(os.listdir(image_folder))
    if name.endswith(('.png', '.jpg', '.jpeg'))
]

# The app expects exactly 34 images. Raise explicitly: an `assert` would be
# skipped when Python runs with -O.
if len(images) != 34:
    raise RuntimeError("There should be exactly 34 images in the 'images' folder.")
# Gallery click handler.
async def select_fn(data: gr.SelectData, prompt: str):
    """Record the clicked gallery image and show it (re)rendered.

    Args:
        data: Gradio selection event; ``data.index`` is the clicked image.
        prompt: current text in the prompt box.

    Returns:
        The untouched image when *prompt* is the empty string, otherwise
        whatever ``predict(prompt)`` produces for the newly selected image.
    """
    # NOTE(review): module-level state, shared by every session.
    global selected_image_index
    clicked = data.index
    selected_image_index = clicked
    if prompt != "":
        return await predict(prompt)
    print("Prompt is empty, returning original image")
    return images[clicked]
async def predict(prompt):
    """Re-render the currently selected gallery image guided by *prompt*.

    Runs one SDXL-Turbo image-to-image pass over ``images[selected_image_index]``
    with a fixed seed, so the same image/prompt pair gives the same output.

    Args:
        prompt: text describing what the user sees in the image.

    Returns:
        The generated PIL image. If the pipeline flags NSFW content, a blank
        512x512 image (and a UI warning). If no image has been selected yet,
        None, which clears the output.
    """
    index = selected_image_index
    if index is None:
        # Typing before clicking a thumbnail: there is no init image to
        # transform, so return nothing instead of failing.
        return None

    # Exactly 0.5; the former literal 0.49999999999999999 rounds to the same
    # float.
    strength = 0.5
    steps = 2
    init_image = resize_crop(images[index])
    # Private CPU generator with a fixed seed. It yields the same samples as
    # torch.manual_seed(123123) without reseeding the global RNG.
    generator = torch.Generator().manual_seed(123123)
    last_time = time.time()
    # Keep int(steps * strength) >= 1 (presumably the number of denoising steps
    # img2img actually runs).
    if int(steps * strength) < 1:
        steps = math.ceil(1 / max(0.10, strength))
    # NOTE(review): this is a blocking call inside a coroutine, so it stalls the
    # event loop while it runs. Moving it to a worker thread would need a lock,
    # since the shared pipeline is presumably not safe to call concurrently.
    results = i2i_pipe(
        prompt=prompt,
        image=init_image,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=0.0,
        strength=strength,
        width=512,
        height=512,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - last_time} seconds")
    # The output only carries this key when a safety checker produced a result.
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        gr.Warning("NSFW content detected.")
        return Image.new("RGB", (512, 512))
    return results.images[0]
# Header markdown shown at the top of the page.
_INTRO_MD = '''# Rorschach 🎭
### 1. Select a CRASH REPORT image
### 2. Describe what you see
<small>CRASH REPORT was a self-published, 72-page book by NoPattern Studio released in November, 2019. Limited to an edition of 300, the book contained a year's worth of experimental, exploratory 3D imagery generated entirely in Photoshop. [CRASH REPORT site](https://nopattern.com/CRASH-REPORT) [see this space's lineage graph](https://huggingface.co/spaces/EQTYLab/lineage-explorer?repo=https://huggingface.co/NoPattern/Rorschach)</small>'''

# Layout: image gallery on the left; prompt box and generated image on the right.
with gr.Blocks() as app:
    gr.Markdown(_INTRO_MD, elem_id="main_title")
    with gr.Row():
        with gr.Column():
            image_gallery = gr.Gallery(value=images, columns=4)
        with gr.Column():
            prompt = gr.Textbox(label="I see...")
            output = gr.Image(label="Generation")

    # Clicking a thumbnail selects it and renders it with the current prompt.
    image_gallery.select(
        select_fn,
        inputs=[prompt],
        outputs=output,
        show_progress=False,
    )
    # Every prompt edit re-renders the selected image.
    prompt.change(
        fn=predict,
        inputs=[prompt],
        outputs=output,
        show_progress=False,
    )

# Enable the request queue, then start the server.
app.queue()
app.launch()