{ "cells": [
 { "cell_type": "code", "execution_count": 1, "id": "initial_id", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T09:44:30.641366Z", "start_time": "2024-12-09T09:44:11.789050Z" } }, "outputs": [], "source": [
  "import os\n",
  "import sys\n",
  "\n",
  "import gradio as gr\n",
  "from diffusers import DiffusionPipeline\n",
  "import matplotlib.pyplot as plt\n",
  "import torch\n",
  "from PIL import Image\n"
 ] },
 { "cell_type": "code", "execution_count": 2, "id": "ddf33e0d3abacc2c", "metadata": {}, "outputs": [], "source": [
  "# Make the hf_demo directory (which provides the `inference` module) importable.\n",
  "# Set the HF_DEMO_DIR environment variable to use a different checkout.\n",
  "HF_DEMO_DIR = os.environ.get(\"HF_DEMO_DIR\", \"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")\n",
  "# append(), not extend(): extend() on a str would add every character as its own sys.path entry.\n",
  "if HF_DEMO_DIR not in sys.path:  # keep the cell idempotent on re-run\n",
  "    sys.path.append(HF_DEMO_DIR)"
 ] },
 { "cell_type": "code", "execution_count": 3, "id": "643e49fd601daf8f", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T09:44:35.790962Z", "start_time": "2024-12-09T09:44:35.779496Z" } }, "outputs": [], "source": [
  "# Expose only GPU 1; this only takes effect if CUDA has not been initialised yet in this kernel.\n",
  "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
  "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
  "# bfloat16 on GPU; float32 on CPU, where float16 kernels are missing or very slow in many PyTorch builds.\n",
  "dtype = torch.bfloat16 if device == \"cuda\" else torch.float32"
 ] },
 { "cell_type": "code", "execution_count": 4, "id": "e03aae2a4e5676dd", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T09:44:44.157412Z", "start_time": "2024-12-09T09:44:37.138452Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] },
 { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "acc42f294243439798e4d77d1a59296d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/7 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
  "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",\n",
  " torch_dtype=dtype).to(device)"
 ] },
 { "cell_type": "code", "execution_count": 5, "id": "83916bc68ff5d914", "metadata": { "ExecuteTime": { "end_time": "2024-12-09T09:44:52.694399Z", "start_time": "2024-12-09T09:44:44.210695Z" } }, "outputs": [], "source": [
  "from inference import get_lora_network, inference, get_validation_dataloader\n",
  "\n",
  "# Display name shown in the UI -> adapter directory under data/Art_adapters.\n",
  "# The special value \"None\" runs the base model without an art adapter.\n",
  "lora_map = {\n",
  "    \"None\": \"None\",\n",
  "    \"Andre Derain (fauvism)\": \"andre-derain_subset1\",\n",
  "    \"Vincent van Gogh (post impressionism)\": \"van_gogh_subset1\",\n",
  "    \"Andy Warhol (pop art)\": \"andy_subset1\",\n",
  "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
  "    \"Camille Corot (realism)\": \"camille-corot_subset1\",\n",
  "    \"Claude Monet (impressionism)\": \"monet_subset2\",\n",
  "    \"Pablo Picasso (cubism)\": \"picasso_subset1\",\n",
  "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
  "    \"Gerhard Richter (abstract expressionism)\": \"gerhard-richter_subset1\",\n",
  "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
  "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
  "    \"Hokusai (ukiyo-e)\": \"katsushika-hokusai_subset1\",\n",
  "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
  "    \"Gustav Klimt (art nouveau)\": \"klimt_subset3\",\n",
  "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
  "    \"Henri Matisse (abstract expressionism)\": \"henri-matisse_subset1\",\n",
  "    \"Joan Miro\": \"joan-miro_subset2\",\n",
  "}\n",
  "\n",
  "# Checkpoint file inside each adapter directory, the trigger prompt used with the adapters,\n",
  "# and the output resolution of every demo image.\n",
  "ADAPTER_PATH_TEMPLATE = \"data/Art_adapters/{}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
  "ADAPTER_STYLE_PROMPT = \"sks art\"\n",
  "IMAGE_SIZE = 512\n",
  "\n",
  "\n",
  "def _resolve_adapter(adapter_choice: str):\n",
  "    \"\"\"Map a `lora_map` display name to `(adapter_path, style_prompt)`.\n",
  "\n",
  "    When no adapter is selected, returns the raw map value and `None`; otherwise the\n",
  "    adapter checkpoint path and the \"sks art\" trigger prompt. Raises KeyError for unknown names.\n",
  "    \"\"\"\n",
  "    adapter_path = lora_map[adapter_choice]\n",
  "    if adapter_path in (None, \"None\"):\n",
  "        return adapter_path, None\n",
  "    return ADAPTER_PATH_TEMPLATE.format(adapter_path), ADAPTER_STYLE_PROMPT\n",
  "\n",
  "\n",
  "def _reference_images(ref_image):\n",
  "    \"\"\"Wrap the numpy reference image from the UI as the one-element PIL list the dataloader takes.\"\"\"\n",
  "    if ref_image is None:\n",
  "        raise ValueError(\"A reference image is required for stylization.\")\n",
  "    return [Image.fromarray(ref_image)]\n",
  "\n",
  "\n",
  "def _run_inference(prompt, adapter_path, scale, style_prompt, seed, steps, guidance_scale,\n",
  "                   ref_images=None, start_noise=-1):\n",
  "    \"\"\"Run `inference` for a single prompt and return the image produced at `scale`.\n",
  "\n",
  "    Without `ref_images` the image is generated from scratch (`from_scratch=True`); with them\n",
  "    the reference is stylized and `start_noise` is forwarded to `inference`.\n",
  "    \"\"\"\n",
  "    if ref_images is None:\n",
  "        infer_loader = get_validation_dataloader([prompt], num_workers=0)\n",
  "    else:\n",
  "        infer_loader = get_validation_dataloader([prompt], ref_images, num_workers=0)\n",
  "    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)[\"network\"]\n",
  "    results = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
  "                        height=IMAGE_SIZE, width=IMAGE_SIZE, scales=[scale],\n",
  "                        save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,\n",
  "                        start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,\n",
  "                        from_scratch=ref_images is None, device=device, weight_dtype=dtype)\n",
  "    # Results appear to be keyed by the requested scale (scales=[0.0] is read back via [0.0]).\n",
  "    # Index with `scale` rather than a fixed 1.0, which raised KeyError for adapter_scale != 1.0.\n",
  "    return results[0][scale][0]\n",
  "\n",
  "\n",
  "def demo_inference_gen_artistic(adapter_choice: str, prompt: str, seed: int = 0, steps=50, guidance_scale=7.5, adapter_scale=1.0):\n",
  "    \"\"\"Generate an image for `prompt` with the selected art adapter applied at `adapter_scale`.\"\"\"\n",
  "    adapter_path, style_prompt = _resolve_adapter(adapter_choice)\n",
  "    return _run_inference(prompt, adapter_path, adapter_scale, style_prompt, seed, steps, guidance_scale)\n",
  "\n",
  "\n",
  "def demo_inference_gen_ori(prompt: str, seed: int = 0, steps=50, guidance_scale=7.5):\n",
  "    \"\"\"Generate an image for `prompt` with the base (art-free) model: no adapter, scale 0.0.\"\"\"\n",
  "    return _run_inference(prompt, \"None\", 0.0, None, seed, steps, guidance_scale)\n",
  "\n",
  "\n",
  "def demo_inference_stylization_ori(ref_image, prompt: str, seed: int = 0, steps=50, guidance_scale=7.5, start_noise=800):\n",
  "    \"\"\"Stylize the numpy `ref_image` towards `prompt` with the base model (no adapter, scale 0.0).\"\"\"\n",
  "    return _run_inference(prompt, \"None\", 0.0, None, seed, steps, guidance_scale,\n",
  "                          ref_images=_reference_images(ref_image), start_noise=start_noise)\n",
  "\n",
  "\n",
  "def demo_inference_stylization_artistic(ref_image, adapter_choice: str, prompt: str, seed: int = 0, steps=50, guidance_scale=7.5, adapter_scale=1.0, start_noise=800):\n",
  "    \"\"\"Stylize the numpy `ref_image` towards `prompt` with the selected art adapter at `adapter_scale`.\"\"\"\n",
  "    adapter_path, style_prompt = _resolve_adapter(adapter_choice)\n",
  "    return _run_inference(prompt, adapter_path, adapter_scale, style_prompt, seed, steps, guidance_scale,\n",
  "                          ref_images=_reference_images(ref_image), start_noise=start_noise)"
 ] },
 { "cell_type": "code", "execution_count": 15, "id": "aa33e9d104023847", "metadata": { "ExecuteTime": { "end_time": "2024-12-10T02:56:13.419303Z", "start_time": "2024-12-10T02:56:13.002796Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7869\n", "\n", "Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB\n", "Running on public URL: https://0fd0c028b349b76a72.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "