# Hugging Face Space status banner: "Running on Zero" (ZeroGPU hardware).
import os | |
import gc | |
import uuid | |
import random | |
import tempfile | |
import time | |
from datetime import datetime | |
from typing import Any | |
from huggingface_hub import login, hf_hub_download | |
import spaces | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image, ImageDraw, ImageFont | |
from diffusers import FluxPipeline | |
from transformers import pipeline | |
# Memory cleanup helper
def clear_memory():
    """Run Python garbage collection and release cached CUDA memory.

    The CUDA allocator cache is only emptied when ``torch.cuda.is_available()``
    is true, so this is safe to call on CPU-only machines. Cleanup is
    best-effort: a failure while emptying the cache is printed and ignored
    rather than propagated to the caller.
    """
    gc.collect()
    try:
        if torch.cuda.is_available():
            with torch.cuda.device(0):
                torch.cuda.empty_cache()
    except Exception as e:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit still
        # propagate, and the failure is reported instead of silently dropped.
        print(f"Warning: could not empty CUDA cache: {e}")
# GPU configuration: pick the compute device and enable fast CUDA math modes.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    try:
        with torch.cuda.device(0):
            torch.cuda.empty_cache()
        # cuDNN autotuning and TF32 matmuls trade a little precision for speed.
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
    except Exception as e:
        # Not fatal: the app still works with default CUDA settings.
        print(f"Warning: Could not configure CUDA settings: {e}")
# Hugging Face token setup: required to download the gated FLUX.1-dev weights.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    # `not` also rejects an empty string, which would otherwise reach login()
    # and fail with a less helpful error.
    raise ValueError("Please set the HF_TOKEN environment variable")
try:
    login(token=HF_TOKEN)
except Exception as e:
    # Chain the original error so its traceback is preserved.
    raise ValueError(f"Failed to login to Hugging Face: {e}") from e
# Korean -> English translation model applied to user prompts before generation.
# device=-1 makes the transformers pipeline run on the CPU.
translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    device=-1,
)
def translate_to_english(text: str) -> str:
    """Translate Korean text to English; return anything else unchanged.

    Translation is triggered only when *text* contains at least one Hangul
    syllable (U+AC00 '가' .. U+D7A3 '힣'). If translation raises, the error
    is printed and the original *text* is returned.
    """
    try:
        contains_hangul = any('가' <= ch <= '힣' for ch in text)
        if not contains_hangul:
            return text
        outputs = translator(text, max_length=128)
        translated = outputs[0]['translation_text']
        print(f"Translated '{text}' to '{translated}'")
        return translated
    except Exception as e:
        print(f"Translation error: {str(e)}")
        return text
# FLUX pipeline initialization
print("Initializing FLUX pipeline...")
try:
    # No device_map here: diffusers raises ValueError when
    # enable_model_cpu_offload() is called on a pipeline loaded with a
    # device-mapping strategy, and CPU offload already manages placement.
    # `token` replaces the deprecated `use_auth_token`; FluxPipeline has no
    # safety_checker component, so that argument is dropped.
    # bfloat16 is the dtype used in the FLUX.1-dev reference usage.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
    )
    print("FLUX pipeline initialized successfully")

    # Memory optimizations
    pipe.enable_attention_slicing(slice_size=1)
    pipe.enable_model_cpu_offload()
    print("Pipeline optimization settings applied")

    # Release allocator cache left over from loading and (re)apply fast CUDA math modes.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
except Exception as e:
    print(f"Error initializing FLUX pipeline: {str(e)}")
    raise
# LoRA weights: download, load, and fuse into the base FLUX pipeline.
print("Loading LoRA weights...")
try:
    # `token` replaces the deprecated `use_auth_token` argument.
    lora_path = hf_hub_download(
        repo_id="openfree/myt-flux-fantasy",
        filename="myt-flux-fantasy.safetensors",
        token=HF_TOKEN,
    )
    print(f"LoRA weights downloaded to: {lora_path}")
    pipe.load_lora_weights(lora_path)
    # Fuse the LoRA into the base weights at a 0.125 scale.
    pipe.fuse_lora(lora_scale=0.125)
    print("LoRA weights loaded and fused successfully")
except Exception as e:
    print(f"Error loading LoRA weights: {str(e)}")
    # Chain the cause so the underlying download/load error stays visible.
    raise ValueError("Failed to load LoRA weights") from e
# Image generation handler for the "Generate Image" button.
# @spaces.GPU requests a GPU for the duration of the call on ZeroGPU Spaces
# (outside ZeroGPU the decorator has no effect).
@spaces.GPU
def generate_image(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress()
):
    """Generate one image from *prompt* with the FLUX + LoRA pipeline.

    Korean prompts are translated to English first, and the resulting image
    is also written to disk via ``save_generated_image``.

    Args:
        prompt: Text prompt (Korean or English).
        seed: RNG seed; replaced by a random one when *randomize_seed* is
            true. Coerced to ``int`` because Gradio number inputs can
            deliver floats.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance strength passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(image, seed)``: the generated PIL image and the seed actually used.

    Raises:
        gr.Error: wrapping any exception raised while generating.
    """
    try:
        clear_memory()
        translated_prompt = translate_to_english(prompt)
        print(f"Processing prompt: {translated_prompt}")

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # torch.Generator.manual_seed() only accepts ints.
        seed = int(seed)
        generator = torch.Generator(device=device).manual_seed(seed)

        # The pipeline already runs in its configured dtype; the deprecated
        # torch.cuda.amp.autocast wrapper is not needed.
        with torch.inference_mode():
            image = pipe(
                prompt=translated_prompt,
                width=int(width),
                height=int(height),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=guidance_scale,
                generator=generator,
                num_images_per_prompt=1,
            ).images[0]

        save_generated_image(image, translated_prompt)
        return image, seed
    except Exception as e:
        print(f"Generation error: {str(e)}")
        raise gr.Error(f"Image generation failed: {str(e)}") from e
    finally:
        clear_memory()
# Seed and image-size limits.
MAX_SEED = np.iinfo(np.int32).max  # largest int32 value; upper bound for random seeds
MAX_IMAGE_SIZE = 1024

# Output directory for generated images, created at import time if missing.
SAVE_DIR = "saved_images"
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR, exist_ok=True)
def save_generated_image(image, prompt):
    """Write *image* as a PNG inside SAVE_DIR and return its path.

    The file name is ``<YYYYmmdd_HHMMSS>_<8 hex chars>.png``; the random
    suffix keeps names unique within the same second. *prompt* is accepted
    for API compatibility but is not used.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # First 8 hex digits of a random UUID (same as the text before its first '-').
    suffix = uuid.uuid4().hex[:8]
    target = os.path.join(SAVE_DIR, f"{stamp}_{suffix}.png")
    image.save(target)
    return target
def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
    """Draw *text* thickened by stamping it at every offset within ±stroke_width.

    Every copy uses the same *text_color*, so this widens the glyphs rather
    than drawing a contrasting outline. ``(2*stroke_width + 1) ** 2`` copies
    are drawn; with ``stroke_width == 0`` the text is drawn once at (x, y).
    """
    offsets = range(-stroke_width, stroke_width + 1)
    for dx in offsets:
        for dy in offsets:
            draw.text((x + dx, y + dy), text, font=font, fill=text_color)
def add_text_to_image( | |
input_image, | |
text, | |
font_size, | |
color, | |
opacity, | |
x_position, | |
y_position, | |
thickness, | |
text_position_type, | |
font_choice | |
): | |
try: | |
if input_image is None or text.strip() == "": | |
return input_image | |
if not isinstance(input_image, Image.Image): | |
if isinstance(input_image, np.ndarray): | |
image = Image.fromarray(input_image) | |
else: | |
raise ValueError("Unsupported image type") | |
else: | |
image = input_image.copy() | |
if image.mode != 'RGBA': | |
image = image.convert('RGBA') | |
font_files = { | |
"Default": "DejaVuSans.ttf", | |
"Korean Regular": "ko-Regular.ttf" | |
} | |
try: | |
font_file = font_files.get(font_choice, "DejaVuSans.ttf") | |
font = ImageFont.truetype(font_file, int(font_size)) | |
except Exception as e: | |
print(f"Font loading error ({font_choice}): {str(e)}") | |
font = ImageFont.load_default() | |
color_map = { | |
'White': (255, 255, 255), | |
'Black': (0, 0, 0), | |
'Red': (255, 0, 0), | |
'Green': (0, 255, 0), | |
'Blue': (0, 0, 255), | |
'Yellow': (255, 255, 0), | |
'Purple': (128, 0, 128) | |
} | |
rgb_color = color_map.get(color, (255, 255, 255)) | |
temp_draw = ImageDraw.Draw(image) | |
text_bbox = temp_draw.textbbox((0, 0), text, font=font) | |
text_width = text_bbox[2] - text_bbox[0] | |
text_height = text_bbox[3] - text_bbox[1] | |
actual_x = int((image.width - text_width) * (x_position / 100)) | |
actual_y = int((image.height - text_height) * (y_position / 100)) | |
text_color = (*rgb_color, int(opacity)) | |
txt_overlay = Image.new('RGBA', image.size, (255, 255, 255, 0)) | |
draw = ImageDraw.Draw(txt_overlay) | |
add_text_with_stroke( | |
draw, | |
text, | |
actual_x, | |
actual_y, | |
font, | |
text_color, | |
int(thickness) | |
) | |
output_image = Image.alpha_composite(image, txt_overlay) | |
output_image = output_image.convert('RGB') | |
return output_image | |
except Exception as e: | |
print(f"Error in add_text_to_image: {str(e)}") | |
return input_image | |
# Custom stylesheet passed to gr.Blocks(css=...): hides the Gradio footer and
# styles the title banner, page container and panel classes.
css = """
footer {display: none}
.main-title {
text-align: center;
margin: 1em 0;
padding: 1.5em;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.main-title h1 {
color: #2196F3;
font-size: 2.8em;
margin-bottom: 0.3em;
font-weight: 700;
}
.main-title p {
color: #555;
font-size: 1.3em;
line-height: 1.4;
}
.container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
.input-panel, .output-panel {
background: white;
padding: 1.5em;
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0,0,0,0.08);
margin-bottom: 1em;
}
"""
# Gradio UI: image generation controls, text-overlay options and event wiring.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML("""
<div class="main-title">
<h1>🎨 Webtoon Canvas</h1>
<p>Generate webtoon-style images and add text with various styles and positions.</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Image generation section
            gen_prompt = gr.Textbox(
                label="Generation Prompt",
                placeholder="Enter your image generation prompt..."
            )
            with gr.Row():
                gen_width = gr.Slider(512, 1024, 768, step=64, label="Width")
                gen_height = gr.Slider(512, 1024, 768, step=64, label="Height")
            with gr.Row():
                guidance_scale = gr.Slider(1, 20, 7.5, step=0.5, label="Guidance Scale")
                num_steps = gr.Slider(1, 50, 30, step=1, label="Number of Steps")
            with gr.Row():
                # precision=0 makes Gradio deliver an int; a plain Number yields a
                # float, which torch.Generator.manual_seed() rejects when
                # "Randomize Seed" is unchecked.
                seed = gr.Number(label="Seed", value=-1, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            generate_btn = gr.Button("Generate Image", variant="primary")
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                show_download_button=True
            )
            output_seed = gr.Number(label="Used Seed", interactive=False)

            # Text overlay section
            with gr.Accordion("Text Options", open=False):
                text_input = gr.Textbox(
                    label="Text Content",
                    placeholder="Enter text to add..."
                )
                text_position_type = gr.Radio(
                    choices=["Text Over Image"],
                    value="Text Over Image",
                    label="Text Position",
                    visible=True
                )
                with gr.Row():
                    font_choice = gr.Dropdown(
                        choices=["Default", "Korean Regular"],
                        value="Default",
                        label="Font Selection",
                        interactive=True
                    )
                    font_size = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=40,
                        step=5,
                        label="Font Size"
                    )
                with gr.Row():
                    color_dropdown = gr.Dropdown(
                        choices=["White", "Black", "Red", "Green", "Blue", "Yellow", "Purple"],
                        value="White",
                        label="Text Color"
                    )
                    thickness = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=1,
                        step=1,
                        label="Text Thickness"
                    )
                with gr.Row():
                    opacity_slider = gr.Slider(
                        minimum=0,
                        maximum=255,
                        value=255,
                        step=1,
                        label="Opacity"
                    )
                with gr.Row():
                    x_position = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Left(0%)~Right(100%)"
                    )
                    y_position = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=50,
                        step=1,
                        label="High(0%)~Low(100%)"
                    )
                add_text_btn = gr.Button("Apply Text", variant="primary")

    # Event bindings
    generate_btn.click(
        fn=generate_image,
        inputs=[
            gen_prompt,
            seed,
            randomize_seed,
            gen_width,
            gen_height,
            guidance_scale,
            num_steps,
        ],
        outputs=[output_image, output_seed]
    )
    # The generated image is both the input and the output, so repeated
    # "Apply Text" clicks stack overlays on the current image.
    add_text_btn.click(
        fn=add_text_to_image,
        inputs=[
            output_image,
            text_input,
            font_size,
            color_dropdown,
            opacity_slider,
            x_position,
            y_position,
            thickness,
            text_position_type,
            font_choice
        ],
        outputs=output_image
    )
# Enable request queuing (at most 5 waiting requests) and start the web server.
# Blocks.queue() returns the Blocks instance, so launch() can be chained.
demo.queue(max_size=5).launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    max_threads=2,
)