Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import gradio as gr | |
import json | |
import logging | |
import torch | |
from PIL import Image | |
import spaces | |
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL | |
from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images | |
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download | |
import copy | |
import random | |
import time | |
from transformers import pipeline | |
import sqlite3 | |
from datetime import datetime | |
# Database initialisation
def init_db():
    """Create the ``images`` table in ``gallery.db`` if it does not exist yet.

    Columns: auto-increment ``id``, ``model_name``, ``prompt``, ``image_path``
    (file name inside the ``gallery`` directory, as written by ``save_image``)
    and ``created_at`` (defaults to the insertion time). Safe to call
    repeatedly thanks to ``IF NOT EXISTS``.
    """
    conn = sqlite3.connect('gallery.db')
    try:
        conn.execute('''CREATE TABLE IF NOT EXISTS images
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      model_name TEXT,
                      prompt TEXT,
                      image_path TEXT,
                      created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')
        conn.commit()
    finally:
        # Release the database handle even if the DDL fails.
        conn.close()
# Image saving
def save_image(image, model_name, prompt):
    """Write ``image`` into the ``gallery`` directory and record it in ``gallery.db``.

    Args:
        image: object with a ``save(path)`` method (a PIL image in this app).
        model_name: LoRA title stored alongside the image.
        prompt: the user's original prompt, stored for the gallery caption.
    """
    import uuid

    os.makedirs("gallery", exist_ok=True)
    # A second-resolution timestamp alone collides when two generations finish
    # within the same second (the later save overwrote the earlier file), so
    # include microseconds plus a short random suffix.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"gallery_{timestamp}_{uuid.uuid4().hex[:8]}.png"
    image.save(os.path.join("gallery", filename))

    conn = sqlite3.connect('gallery.db')
    try:
        conn.execute(
            "INSERT INTO images (model_name, prompt, image_path) VALUES (?, ?, ?)",
            (model_name, prompt, filename),
        )
        conn.commit()
    finally:
        # Always release the database handle, even if the insert fails.
        conn.close()
# Gallery image loading
def load_gallery_images():
    """Return saved gallery entries, newest first, as ``(path, caption)`` tuples.

    ``path`` is the image file inside the ``gallery`` directory and
    ``caption`` is ``"<model_name>: <prompt>"``.
    """
    conn = sqlite3.connect('gallery.db')
    try:
        # CURRENT_TIMESTAMP has one-second resolution, so break ties on the
        # auto-increment id to keep same-second images in newest-first order.
        rows = conn.execute(
            "SELECT model_name, prompt, image_path FROM images "
            "ORDER BY created_at DESC, id DESC"
        ).fetchall()
    finally:
        conn.close()
    return [
        (os.path.join("gallery", image_path), f"{model_name}: {prompt}")
        for model_name, prompt, image_path in rows
    ]
# Korean -> English translation pipeline, pinned to the CPU (device=-1) so it
# never competes for the GPU used by the image model.
translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    device=-1,
)
# Prompt preprocessing
def process_prompt(prompt):
    """Return ``(original_prompt, english_prompt)`` for a user prompt.

    When the prompt contains any Hangul character (compatibility jamo
    U+3131-U+3163 or precomposed syllables U+AC00-U+D7A3), the whole prompt is
    run through the module-level ``translator``. Otherwise the prompt is
    returned unchanged in both positions.
    """
    contains_hangul = any(
        "\u3131" <= ch <= "\u3163" or "\uac00" <= ch <= "\ud7a3"
        for ch in prompt
    )
    if not contains_hangul:
        return prompt, prompt
    english = translator(prompt)[0]["translation_text"]
    return prompt, english
# Path to the JSON file listing the selectable LoRAs, supplied via the
# KEY_JSON environment variable.
KEY_JSON = os.getenv("KEY_JSON")
if not KEY_JSON:
    # Fail fast with an actionable message instead of open(None) raising an
    # unrelated TypeError.
    raise RuntimeError(
        "KEY_JSON environment variable is not set; it must point to the LoRA list JSON file."
    )
# JSON text is UTF-8 (RFC 8259); do not depend on the platform's default
# encoding, which may not be UTF-8 and would break non-ASCII titles.
with open(KEY_JSON, 'r', encoding='utf-8') as f:
    loras = json.load(f)
# Base model setup: FLUX.1-dev in bfloat16.
MAX_SEED = (1 << 32) - 1
base_model = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tiny autoencoder installed as the pipeline's VAE.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
taef1 = taef1.to(device)
# Full-size VAE from the base model; handed to the iterable pipeline call as
# `good_vae` (presumably for the final high-quality decode — see live_preview_helpers).
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype)
good_vae = good_vae.to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1)
pipe = pipe.to(device)

# Bind the streaming helper to `pipe` as a method so it can be called as
# pipe.flux_pipe_call_that_returns_an_iterable_of_images(...).
pipe.flux_pipe_call_that_returns_an_iterable_of_images = (
    flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
)
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Usage::

        with calculateDuration("step name"):
            ...

    On exit it prints ``Elapsed time for <name>: <seconds> seconds``, or
    ``Elapsed time: <seconds> seconds`` when no name is given. After exit,
    ``elapsed_time`` holds the measured duration in seconds. Exceptions raised
    in the body are not suppressed.
    """

    def __init__(self, activity_name=""):
        # Label included in the printed message; an empty string means unlabeled.
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump (NTP/DST adjustments), yielding wrong or
        # negative durations.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # A falsy return value lets any exception from the body propagate.
        return None
def update_selection(evt: gr.SelectData, width, height):
    """Handle a click on the LoRA gallery.

    Returns values for the outputs (prompt, selected_info, selected_index,
    width, height):
    - a placeholder update naming the chosen LoRA;
    - a markdown link to its Hugging Face repo;
    - the clicked index;
    - the output size.

    Size rules:
    - no "aspect" key gives 1024x1024;
    - "portrait" gives 768x1024;
    - "landscape" gives 1024x768;
    - any other aspect value keeps the incoming width/height.
    """
    lora = loras[evt.index]
    placeholder_text = f"{lora['title']}๋ฅผ ์ํ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํ์ธ์"
    repo = lora["repo"]
    info_markdown = f"### ์ ํ๋จ: [{repo}](https://huggingface.co/{repo}) โจ"

    if "aspect" not in lora:
        width, height = 1024, 1024
    elif lora["aspect"] == "portrait":
        width, height = 768, 1024
    elif lora["aspect"] == "landscape":
        width, height = 1024, 768

    return (
        gr.update(placeholder=placeholder_text),
        info_markdown,
        evt.index,
        width,
        height,
    )
@spaces.GPU
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
    """Run the FLUX pipeline on the GPU and yield each image it produces.

    ``@spaces.GPU`` asks ZeroGPU to attach a GPU for the duration of this
    call. On ZeroGPU hardware, CUDA work is meant to run inside such a
    decorated function. Before this fix no function in the file was decorated,
    even though ``spaces`` is imported.

    NOTE(review): the decorator's default time budget may be too short for
    50 steps at 1536x1536. Use ``@spaces.GPU(duration=...)`` if requests time out.

    Args:
        prompt_mash: final English prompt, including any LoRA trigger word.
        steps: number of inference steps.
        seed: seed for the CUDA ``torch.Generator``, for reproducibility.
        cfg_scale: passed to the pipeline as ``guidance_scale``.
        width, height: output size in pixels.
        lora_scale: LoRA strength, passed via ``joint_attention_kwargs``.
        progress: not used inside this function; kept for the caller's signature.

    Yields:
        Images (``output_type="pil"``) produced by
        ``pipe.flux_pipe_call_that_returns_an_iterable_of_images``.
    """
    pipe.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(seed)
    with calculateDuration("์ด๋ฏธ์ง ์์ฑ"):
        # Stream each image the helper yields straight back to the caller.
        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt_mash,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
            output_type="pil",
            good_vae=good_vae,
        ):
            yield img
def _build_prompt_mash(lora, english_prompt):
    """Combine the English prompt with the LoRA's trigger word, if it has one.

    The trigger word is prepended, unless the LoRA entry has a
    "trigger_position" key whose value is not "prepend". In that case it is
    appended. Entries without a (truthy) "trigger_word" use the prompt as is.
    """
    trigger_word = lora.get("trigger_word")
    if not trigger_word:
        return english_prompt
    if lora.get("trigger_position", "prepend") == "prepend":
        return f"{trigger_word} {english_prompt}"
    return f"{english_prompt} {trigger_word}"


def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the selected LoRA, streaming intermediate results.

    Yields 5-tuples matching the Gradio outputs
    ``(result, seed, progress_bar, original_prompt_display, english_prompt_display)``.
    The final image is saved to the gallery before the last yield, which
    hides the progress bar.

    Raises:
        gr.Error: if no LoRA is selected, or if the pipeline produced no image.
    """
    if selected_index is None:
        raise gr.Error("์งํํ๊ธฐ ์ ์ LoRA๋ฅผ ์ ํํด์ผ ํฉ๋๋ค.")
    original_prompt, english_prompt = process_prompt(prompt)
    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    prompt_mash = _build_prompt_mash(selected_lora, english_prompt)

    # Drop whatever LoRA the previous request left on the shared pipeline.
    with calculateDuration("LoRA ์ธ๋ก๋"):
        pipe.unload_lora_weights()

    # Load the selected LoRA's weights (a specific weight file if the entry names one).
    with calculateDuration(f"{selected_lora['title']}์ LoRA ๊ฐ์ค์น ๋ก๋"):
        if "weights" in selected_lora:
            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
        else:
            pipe.load_lora_weights(lora_path)

    # Pick a new seed when requested; the seed actually used is sent back to the UI.
    with calculateDuration("์๋ ๋ฌด์์ํ"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)

    # Consume the generator, streaming every image; the last one is the final result.
    final_image = None
    for step_counter, image in enumerate(image_generator, start=1):
        final_image = image
        progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
        yield image, seed, gr.update(value=progress_bar, visible=True), original_prompt, english_prompt

    # Without this guard an empty generator left `progress_bar` unbound and
    # passed None to save_image, crashing with an unrelated error.
    if final_image is None:
        raise gr.Error("Image generation produced no output.")

    # Save the final image to the gallery.
    save_image(final_image, selected_lora['title'], original_prompt)
    yield final_image, seed, gr.update(value=progress_bar, visible=False), original_prompt, english_prompt
def get_huggingface_safetensors(link):
    """Resolve a Hugging Face repo id ("owner/name") to custom-LoRA metadata.

    Returns:
        ``(title, repo_id, safetensors_file_name, trigger_word, image_url)``.
        ``title`` is the repo name. ``trigger_word`` is the model card's
        ``instance_prompt`` (or ""). ``image_url`` may be None when neither the
        card widget nor the repo provides an image.

    Raises:
        ValueError: if ``link`` is not of the form "owner/name".
        Exception: if the card's base model is not FLUX.1-dev/-schnell, or the
            repository cannot be listed or contains no ``.safetensors`` file.
    """
    flux_base_models = ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell")
    split_link = link.split("/")
    if len(split_link) != 2:
        # Previously this fell through and returned None, which callers then
        # failed to unpack with an unrelated TypeError.
        raise ValueError(f"Expected a Hugging Face repository id like 'owner/name', got {link!r}")

    model_card = ModelCard.load(link)
    card_base_model = model_card.data.get("base_model")
    print(card_base_model)
    # Model cards may declare base_model as a single id or as a list of ids.
    declared_bases = card_base_model if isinstance(card_base_model, list) else [card_base_model]
    if not any(candidate in flux_base_models for candidate in declared_bases):
        raise Exception("Not a FLUX LoRA!")

    # Preview image: the first widget example's output, if the card has one.
    widgets = model_card.data.get("widget") or [{}]
    first_output = (widgets[0] or {}).get("output") or {}
    image_path = first_output.get("url")
    trigger_word = model_card.data.get("instance_prompt") or ""
    image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

    error_message = "You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA"
    fs = HfFileSystem()
    try:
        list_of_files = fs.ls(link, detail=False)
    except Exception as e:
        print(e)
        gr.Warning(error_message)
        raise Exception(error_message) from e

    safetensors_name = None
    for file in list_of_files:
        file_name = file.split("/")[-1]
        # Keeps the last .safetensors file listed, as before.
        if file.endswith(".safetensors"):
            safetensors_name = file_name
        # Fall back to the first image file in the repo when the card had none.
        if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
            image_url = f"https://huggingface.co/{link}/resolve/main/{file_name}"

    if safetensors_name is None:
        # Previously this surfaced as an UnboundLocalError at the return.
        gr.Warning(error_message)
        raise Exception(error_message)

    return split_link[1], link, safetensors_name, trigger_word, image_url
def check_custom_model(link):
    """Normalise user input to a Hugging Face repo id and fetch its LoRA metadata.

    Accepts either a bare repo id ("owner/name") or an http(s) URL on
    huggingface.co / www.huggingface.co. Extra URL path segments (e.g.
    "/tree/main") and trailing slashes are dropped. Returns whatever
    ``get_huggingface_safetensors`` returns for the resulting repo id.

    Raises:
        ValueError: for an http(s) URL on any other host.
    """
    link = link.strip()
    if link.startswith(("https://", "http://")):
        host, _, path = link.split("://", 1)[1].partition("/")
        # Exact host match: a prefix check would also accept look-alike hosts
        # such as "huggingface.co.example.com".
        if host.lower() not in ("huggingface.co", "www.huggingface.co"):
            # Reject explicitly rather than returning None, which the caller
            # would fail to unpack with an unrelated error.
            raise ValueError(f"Only Hugging Face model URLs are supported, got {link!r}")
        link = "/".join(path.strip("/").split("/")[:2])
    return get_huggingface_safetensors(link)
def _custom_lora_card_html(title, trigger_word, image):
    """Build the HTML card shown once a custom LoRA loads; all card text is HTML-escaped."""
    import html

    # Model-card fields come from arbitrary public repos, so escape them
    # before embedding them in HTML.
    image_tag = f'<img src="{html.escape(str(image))}" />' if image else ""
    if trigger_word:
        trigger_note = f"Using: <code><b>{html.escape(str(trigger_word))}</b></code> as the trigger word"
    else:
        trigger_note = "No trigger word found. If there's a trigger word, include it in your prompt"
    return f'''
    <div class="custom_lora_card">
      <span>Loaded custom LoRA:</span>
      <div class="card_internal">
        {image_tag}
        <div>
          <h3>{html.escape(str(title))}</h3>
          <small>{trigger_note}<br></small>
        </div>
      </div>
    </div>
    '''


def add_custom_lora(custom_lora):
    """Load a user-supplied LoRA (repo id or URL) and make it the current selection.

    Returns values for the outputs (custom_lora_info, custom_lora_button,
    gallery, selected_info, selected_index, prompt). Empty input hides the
    custom-LoRA widgets. Invalid input shows an error card instead of raising.
    """
    global loras
    if not custom_lora:
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
    try:
        title, repo, path, trigger_word, image = check_custom_model(custom_lora)
        print(f"Loaded custom LoRA: {repo}")
        card = _custom_lora_card_html(title, trigger_word, image)
        existing_item_index = next((index for index, item in enumerate(loras) if item['repo'] == repo), None)
        # Compare with None explicitly: index 0 is a valid (but falsy) position,
        # and treating it as "missing" appended duplicate entries.
        if existing_item_index is None:
            new_item = {
                "image": image,
                "title": title,
                "repo": repo,
                "weights": path,
                "trigger_word": trigger_word
            }
            print(new_item)
            existing_item_index = len(loras)
            loras.append(new_item)
        return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
    except Exception as e:
        # Keep the UI alive on bad input, but log the cause for debugging.
        print(e)
        gr.Warning("Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
        return gr.update(visible=True, value="Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=True), gr.update(), "", None, ""
def remove_custom_lora():
    """Reset the custom-LoRA widgets.

    Returns values for (custom_lora_info, custom_lora_button, gallery,
    selected_info, selected_index, custom_lora):
    - hide the card and the remove button;
    - leave the gallery untouched;
    - clear the info text, the selection and the textbox.
    """
    hide_card = gr.update(visible=False)
    hide_button = gr.update(visible=False)
    keep_gallery = gr.update()
    cleared_info, cleared_index, cleared_textbox = "", None, ""
    return hide_card, hide_button, keep_gallery, cleared_info, cleared_index, cleared_textbox
# Marker attribute on the generation handler.
# NOTE(review): nothing visible in this file reads it; confirm whether it is still needed.
run_lora.zerogpu = True

# Custom CSS: hide the Gradio footer and style the step-progress bar that
# run_lora() emits as HTML (.progress-container / .progress-bar driven by the
# --current / --total CSS variables). Without these rules the bar had no
# size and never rendered.
css = """
footer {
    visibility: hidden;
}
#progress {
    height: 30px;
}
.progress-container {
    width: 100%;
    height: 30px;
    background-color: #f0f0f0;
    border-radius: 15px;
    overflow: hidden;
    margin-bottom: 20px;
}
.progress-bar {
    height: 100%;
    background-color: #4f46e5;
    width: calc(var(--current) / var(--total) * 100%);
    transition: width 0.5s ease-in-out;
}
"""

# Create the gallery directory (exist_ok avoids the check-then-create race).
os.makedirs('gallery', exist_ok=True)

# Initialise the gallery database.
init_db()
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    # Index into `loras` of the currently selected LoRA (None until one is picked).
    selected_index = gr.State(None)
    with gr.Tabs():
        # Generation tab: prompt, LoRA picker, output and advanced settings.
        with gr.TabItem("์์ฑ"):
            with gr.Row():
                with gr.Column(scale=3):
                    prompt = gr.Textbox(label="ํ๋กฌํํธ", lines=1, placeholder="LoRA๋ฅผ ์ ํํ ํ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํ์ธ์ (ํ๊ธ ๋๋ ์์ด)")
                with gr.Column(scale=1, elem_id="gen_column"):
                    generate_button = gr.Button("์์ฑ", variant="primary", elem_id="gen_btn")
            with gr.Row():
                with gr.Column():
                    selected_info = gr.Markdown("")
                    gallery = gr.Gallery(
                        [(item["image"], item["title"]) for item in loras],
                        label="LoRA ๊ฐค๋ฌ๋ฆฌ",
                        allow_preview=False,
                        columns=3,
                        elem_id="gallery"
                    )
                    with gr.Group():
                        custom_lora = gr.Textbox(label="์ปค์คํ LoRA", info="LoRA Hugging Face ๊ฒฝ๋ก", placeholder="multimodalart/vintage-ads-flux")
                        gr.Markdown("[FLUX LoRA ๋ชฉ๋ก ํ์ธ](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
                    custom_lora_info = gr.HTML(visible=False)
                    custom_lora_button = gr.Button("์ปค์คํ LoRA ์ ๊ฑฐ", visible=False)
                with gr.Column():
                    progress_bar = gr.Markdown(elem_id="progress", visible=False)
                    result = gr.Image(label="์์ฑ๋ ์ด๋ฏธ์ง")
                    original_prompt_display = gr.Textbox(label="์๋ณธ ํ๋กฌํํธ")
                    english_prompt_display = gr.Textbox(label="์์ด ํ๋กฌํํธ")
            with gr.Row():
                with gr.Accordion("๊ณ ๊ธ ์ค์ ", open=False):
                    with gr.Column():
                        with gr.Row():
                            cfg_scale = gr.Slider(label="CFG ์ค์ผ์ผ", minimum=1, maximum=20, step=0.5, value=3.5)
                            steps = gr.Slider(label="์คํ ", minimum=1, maximum=50, step=1, value=28)
                        with gr.Row():
                            width = gr.Slider(label="๋๋น", minimum=256, maximum=1536, step=64, value=1024)
                            height = gr.Slider(label="๋์ด", minimum=256, maximum=1536, step=64, value=1024)
                        with gr.Row():
                            randomize_seed = gr.Checkbox(True, label="์๋ ๋ฌด์์ํ")
                            seed = gr.Slider(label="์๋", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                            lora_scale = gr.Slider(label="LoRA ์ค์ผ์ผ", minimum=0, maximum=3, step=0.01, value=0.95)
        # Gallery tab: previously generated images recorded in gallery.db.
        with gr.TabItem("๊ฐค๋ฌ๋ฆฌ"):
            # Pass the function itself (not its result) so Gradio re-runs it on
            # every page load; a precomputed list froze the tab at server start.
            gallery_images = gr.Gallery(
                load_gallery_images,
                label="์์ฑ๋ ์ด๋ฏธ์ง ๊ฐค๋ฌ๋ฆฌ",
                columns=3,
                rows=3,
                height="auto"
            )
            refresh_button = gr.Button("๊ฐค๋ฌ๋ฆฌ ์๋ก๊ณ ์นจ")

    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, selected_index, width, height]
    )
    custom_lora.input(
        add_custom_lora,
        inputs=[custom_lora],
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
    )
    custom_lora_button.click(
        remove_custom_lora,
        outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
    )
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed,
                seed, width, height, lora_scale],
        outputs=[result, seed, progress_bar, original_prompt_display, english_prompt_display]
    ).then(
        # run_lora saves each finished image, so refresh the gallery tab afterwards.
        load_gallery_images,
        outputs=[gallery_images]
    )
    refresh_button.click(
        load_gallery_images,
        outputs=[gallery_images]
    )
# Enable the request queue on the Blocks app, then start the Gradio server.
app.queue().launch()