import importlib
from typing import List

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionPipeline
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

from image_utils import make_grid, numpy_to_pil
from metrics_utils import compute_main_metrics, compute_psnr_or_ssim
from report_utils import add_psnr_ssim_to_report, prepare_report

SEED = 0
WEIGHT_DTYPE = torch.float16

TITLE = "Evaluate Schedulers with StableDiffusionPipeline 🧨"
ABSTRACT = """
This Space allows you to quantitatively compare [different noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers) with a [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).

One of the applications of this Space could be to evaluate different schedulers for a certain Stable Diffusion checkpoint for a fixed number of inference steps.
"""
DESCRIPTION = """
#### How does it work?

* The evaluator first sets a seed and then generates the initial noise which is passed as the initial latent to start the image generation process. It is done to ensure fair comparison.
* This initial latent is used every time the pipeline is run (with different schedulers).
* To quantify the quality of the generated images we use:
    * [Inception Score](https://en.wikipedia.org/wiki/Inception_score)
    * [Clip Score](https://arxiv.org/abs/2104.08718)

#### Notes

* When selecting a model checkpoint, if you select "Other" you will have the option to provide a custom Stable Diffusion checkpoint.
* The default scheduler associated with the provided checkpoint is always used for reporting the scores.
* Increasing both the number of images per prompt and the number of inference steps could quickly build up the inference queue and thus result in slowdowns.
"""

# Module-level metric objects shared by every request.
psnr_fn = PeakSignalNoiseRatio()
ssim_fn = StructuralSimilarityIndexMeasure()


def initialize_pipeline(checkpoint: str):
    """Load ``checkpoint`` as a half-precision pipeline on CUDA.

    Returns:
        A ``(pipeline, original_scheduler_config)`` tuple; the config is the
        checkpoint's own scheduler config, used to build the other schedulers.
    """
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        checkpoint, torch_dtype=WEIGHT_DTYPE
    )
    sd_pipe = sd_pipe.to("cuda")
    original_scheduler_config = sd_pipe.scheduler.config
    return sd_pipe, original_scheduler_config


def get_scheduler(scheduler_name: str):
    """Return the scheduler class named ``scheduler_name`` exported by ``diffusers``.

    Raises:
        AttributeError: if ``diffusers`` has no attribute with that name.
    """
    # The scheduler classes are re-exported at the top level of `diffusers`;
    # `import_module`'s `package=` argument only matters for relative imports.
    schedulers_lib = importlib.import_module("diffusers")
    return getattr(schedulers_lib, scheduler_name)


def get_latents(
    num_images_per_prompt: int,
    seed=SEED,
    num_channels: int = 4,
    height: int = 64,
    width: int = 64,
):
    """Create deterministic initial latents shared by every scheduler run.

    Args:
        num_images_per_prompt: batch size of the latent tensor.
        seed: seed for both the NumPy noise and torch's global RNG.
        num_channels, height, width: latent shape; the defaults (4, 64, 64)
            match the previously hard-coded shape.

    Returns:
        A CUDA tensor of shape ``(num_images_per_prompt, num_channels,
        height, width)`` in ``WEIGHT_DTYPE``.
    """
    # Seed torch's global RNG too, so any noise a scheduler draws during
    # sampling is reproducible across runs.
    torch.manual_seed(seed)
    latents = np.random.RandomState(seed).standard_normal(
        (num_images_per_prompt, num_channels, height, width)
    )
    return torch.from_numpy(latents).to(device="cuda", dtype=WEIGHT_DTYPE)


def run(
    prompt: str,
    num_images_per_prompt: int,
    num_inference_steps: int,
    checkpoint: str,
    other_finedtuned_checkpoints: str = None,
    schedulers_to_test: List[str] = None,
    ssim: bool = False,
    psnr: bool = False,
    progress=gr.Progress(),
):
    """Generate images with each requested scheduler and build a Markdown report.

    The checkpoint's own (default) scheduler is always evaluated, last, and
    serves as the reference for the optional SSIM/PSNR comparison.

    Returns:
        The Markdown report, or an error message when "Other" is selected
        without a usable custom checkpoint.
    """
    progress(0, desc="Starting...")

    if checkpoint == "Other":
        custom_checkpoint = (other_finedtuned_checkpoints or "").strip()
        if not custom_checkpoint:
            return "❌ No legit checkpoint provided ❌"
        checkpoint = custom_checkpoint

    num_images_per_prompt = int(num_images_per_prompt)
    num_inference_steps = int(num_inference_steps)

    all_images = {}
    scheduler_images = {}

    # Set up the pipeline.
    sd_pipeline, original_scheduler_config = initialize_pipeline(checkpoint)
    sd_pipeline.set_progress_bar_config(disable=True)
    # Keep the checkpoint's own scheduler so it can be restored for the
    # baseline run after other schedulers have been swapped in.
    original_scheduler = sd_pipeline.scheduler

    # Prepare latents sized for this checkpoint's UNet, and the prompts.
    unet_config = sd_pipeline.unet.config
    sample_size = unet_config.sample_size
    if isinstance(sample_size, int):
        latent_height = latent_width = sample_size
    else:
        latent_height, latent_width = sample_size
    latents = get_latents(
        num_images_per_prompt,
        num_channels=unet_config.in_channels,
        height=latent_height,
        width=latent_width,
    )
    prompts = [prompt] * num_images_per_prompt

    original_scheduler_name = original_scheduler_config._class_name
    # Build a fresh list (never mutate the caller's), drop duplicate entries
    # and any explicit selection of the baseline, then evaluate it last.
    scheduler_names = [
        name
        for name in dict.fromkeys(schedulers_to_test or [])
        if name != original_scheduler_name
    ]
    scheduler_names.append(original_scheduler_name)

    # Start generating the images and computing their scores.
    for scheduler_name in progress.tqdm(scheduler_names):
        if scheduler_name == original_scheduler_name:
            sd_pipeline.scheduler = original_scheduler
        else:
            scheduler_cls = get_scheduler(scheduler_name)
            sd_pipeline.scheduler = scheduler_cls.from_config(
                original_scheduler_config
            )

        cur_scheduler_images = sd_pipeline(
            prompts,
            latents=latents,
            num_inference_steps=num_inference_steps,
            output_type="numpy",
        ).images

        all_images[scheduler_name] = {
            "images": make_grid(
                numpy_to_pil(cur_scheduler_images), 1, num_images_per_prompt
            ),
            "scores": compute_main_metrics(cur_scheduler_images, prompts),
        }
        scheduler_images[scheduler_name] = cur_scheduler_images
        torch.cuda.empty_cache()

    # The remaining work only needs the generated arrays; release the GPU.
    del sd_pipeline, original_scheduler
    torch.cuda.empty_cache()

    # Prepare output report.
    output_str = "".join(
        prepare_report(name, result) for name, result in all_images.items()
    )

    # Append PSNR or SSIM if there is anything to compare against the baseline.
    if len(scheduler_names) > 1:
        ssim_scores = psnr_scores = None
        if ssim:
            ssim_scores = compute_psnr_or_ssim(
                ssim_fn, scheduler_images, original_scheduler_name
            )
        if psnr:
            psnr_scores = compute_psnr_or_ssim(
                psnr_fn, scheduler_images, original_scheduler_name
            )
        ssim_psnr_str = add_psnr_ssim_to_report(
            original_scheduler_name, ssim_scores, psnr_scores
        )
        if ssim_psnr_str:
            output_str += ssim_psnr_str

    return output_str


with gr.Blocks(title="Scheduler Evaluation") as demo:
    gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

    with gr.Row():
        with gr.Column():
            prompt = gr.Text(
                max_lines=1, placeholder="a painting of a dog", label="prompt"
            )
            num_images_per_prompt = gr.Slider(
                3, 10, value=3, step=1, label="num_images_per_prompt"
            )
            num_inference_steps = gr.Slider(
                10, 100, value=50, step=1, label="num_inference_steps"
            )
            model_ckpt = gr.Dropdown(
                [
                    "CompVis/stable-diffusion-v1-4",
                    "runwayml/stable-diffusion-v1-5",
                    "stabilityai/stable-diffusion-2-base",
                    "Other",
                ],
                value="CompVis/stable-diffusion-v1-4",
                multiselect=False,
                interactive=True,
                label="model_ckpt",
            )
            other_finedtuned_checkpoints = gr.Textbox(
                visible=False,
                interactive=True,
                placeholder="valhalla/sd-pokemon-model",
                label="custom_checkpoint",
            )
            # Show the custom-checkpoint textbox only when "Other" is chosen.
            model_ckpt.change(
                lambda x: gr.Textbox.update(visible=x == "Other"),
                model_ckpt,
                other_finedtuned_checkpoints,
            )
            schedulers_to_test = gr.Dropdown(
                [
                    "EulerDiscreteScheduler",
                    "PNDMScheduler",
                    "LMSDiscreteScheduler",
                    "DPMSolverMultistepScheduler",
                    "DDIMScheduler",
                ],
                value=["LMSDiscreteScheduler"],
                multiselect=True,
                label="schedulers_to_test",
            )
            ssim = gr.Checkbox(label="Compute SSIM")
            psnr = gr.Checkbox(label="Compute PSNR")
            evaluation_button = gr.Button(value="Submit")

        with gr.Column():
            report = gr.Markdown(label="Evaluation Report")

    evaluation_button.click(
        run,
        inputs=[
            prompt,
            num_images_per_prompt,
            num_inference_steps,
            model_ckpt,
            other_finedtuned_checkpoints,
            schedulers_to_test,
            ssim,
            psnr,
        ],
        outputs=report,
    )

    gr.Markdown(f"{DESCRIPTION}")

demo.queue().launch(debug=True)