openfree committed on
Commit 5035677 · verified · 1 Parent(s): c23ced1

Update app.py

Files changed (1)
  1. app.py +39 -46
app.py CHANGED
@@ -76,25 +76,29 @@ print("Initializing FLUX pipeline...")
 try:
     pipe = FluxPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-dev",
-        torch_dtype=torch.float16,  # use half precision
+        torch_dtype=torch.float16,
         use_auth_token=HF_TOKEN,
         safety_checker=None,
-        variant="fp16",  # use the fp16 variant
-        device_map="auto"  # automatic device mapping
+        device_map="balanced"  # use 'balanced' instead of 'auto'
     )
     print("FLUX pipeline initialized successfully")

-    # stronger memory-optimization settings
-    pipe.enable_attention_slicing(slice_size=1)  # smaller slice size
-    pipe.enable_model_cpu_offload()  # enable CPU offloading
-    pipe.enable_sequential_cpu_offload()  # sequential CPU offloading
+    # memory-optimization settings
+    pipe.enable_attention_slicing(slice_size=1)
+    pipe.enable_model_cpu_offload()
     print("Pipeline optimization settings applied")

+    # additional memory optimizations
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+
 except Exception as e:
     print(f"Error initializing FLUX pipeline: {str(e)}")
     raise

-# LoRA weight loading section (modified)
+# LoRA weight loading section
 print("Loading LoRA weights...")
 try:
     lora_path = hf_hub_download(
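With the switch to device_map="balanced" plus CPU offloading above, it can be useful to check where the pipeline's components actually landed and how much GPU memory they occupy. The helper below is a minimal sketch for that, not part of the commit; report_pipeline_memory is a hypothetical name, and hf_device_map is only populated when diffusers loads the pipeline with a device_map.

import torch

def report_pipeline_memory(pipe):
    # hf_device_map is set by diffusers only when device_map was used at load time
    device_map = getattr(pipe, "hf_device_map", None)
    print(f"Device map: {device_map}")
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"GPU memory: {allocated:.2f} GiB allocated, {reserved:.2f} GiB reserved")

# Example: call right after the pipeline is built
# report_pipeline_memory(pipe)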
@@ -104,39 +108,14 @@ try:
     )
     print(f"LoRA weights downloaded to: {lora_path}")

-    # load the LoRA weights (memory-efficient approach)
-    pipe.load_lora_weights(lora_path, adapter_name="fantasy")
+    pipe.load_lora_weights(lora_path)
     pipe.fuse_lora(lora_scale=0.125)
-
-    # free memory that is no longer needed
-    torch.cuda.empty_cache()
-    gc.collect()
-
     print("LoRA weights loaded and fused successfully")

 except Exception as e:
     print(f"Error loading LoRA weights: {str(e)}")
     raise ValueError("Failed to load LoRA weights")

-
-
-
-# save-directory setup
-SAVE_DIR = "saved_images"
-if not os.path.exists(SAVE_DIR):
-    os.makedirs(SAVE_DIR, exist_ok=True)
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-def save_generated_image(image, prompt):
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    unique_id = str(uuid.uuid4())[:8]
-    filename = f"{timestamp}_{unique_id}.png"
-    filepath = os.path.join(SAVE_DIR, filename)
-    image.save(filepath)
-    return filepath
-
 # generate_image function (modified)
 @spaces.GPU(duration=60)
 def generate_image(
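The hunk above simplifies LoRA handling to a plain load-then-fuse: fuse_lora(lora_scale=0.125) folds the adapter into the base weights at roughly 12.5% strength. The snippet below is an illustrative sketch of that pattern, assuming the diffusers LoRA loader API; the unload_lora_weights() call is an addition for releasing the unfused adapter copy and is not something this commit does.

# Illustrative load -> fuse -> (optionally) unload pattern, assuming diffusers' LoRA API.
pipe.load_lora_weights(lora_path)   # lora_path comes from hf_hub_download above
pipe.fuse_lora(lora_scale=0.125)    # bake the adapter into the base weights at scale 0.125
pipe.unload_lora_weights()          # assumption: drop the unfused adapter copy to save memory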
@@ -150,21 +129,17 @@ def generate_image(
     progress: gr.Progress = gr.Progress()
 ):
     try:
-        print(f"\nStarting image generation with prompt: {prompt}")
-
-        # clean up memory
         clear_memory()

         translated_prompt = translate_to_english(prompt)
-        print(f"Translated prompt: {translated_prompt}")
+        print(f"Processing prompt: {translated_prompt}")

         if randomize_seed:
             seed = random.randint(0, MAX_SEED)

         generator = torch.Generator(device=device).manual_seed(seed)

-        # keep the batch size at 1 to minimize memory use
-        with torch.inference_mode(), torch.cuda.amp.autocast():
+        with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True):
             image = pipe(
                 prompt=translated_prompt,
                 width=width,
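One detail in the inference block above: recent PyTorch releases deprecate torch.cuda.amp.autocast in favour of the device-agnostic torch.amp.autocast. A minimal equivalent, assuming PyTorch 2.x and the same pipeline call as in the hunks, would be:

# Equivalent context managers on newer PyTorch (assumption: torch >= 2.x).
with torch.inference_mode(), torch.amp.autocast("cuda", enabled=True):
    image = pipe(
        prompt=translated_prompt,
        # ... remaining arguments exactly as in the call shown in the hunk above ...
    ).images[0]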
@@ -176,16 +151,34 @@ def generate_image(
             ).images[0]

         filepath = save_generated_image(image, translated_prompt)
-
-        # clean up memory
-        clear_memory()
-
         return image, seed

     except Exception as e:
-        print(f"Error in generate_image: {str(e)}")
-        clear_memory()
+        print(f"Generation error: {str(e)}")
         raise gr.Error(f"Image generation failed: {str(e)}")
+    finally:
+        clear_memory()
+
+
+
+
+# save-directory setup
+SAVE_DIR = "saved_images"
+if not os.path.exists(SAVE_DIR):
+    os.makedirs(SAVE_DIR, exist_ok=True)
+
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
+
+def save_generated_image(image, prompt):
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    unique_id = str(uuid.uuid4())[:8]
+    filename = f"{timestamp}_{unique_id}.png"
+    filepath = os.path.join(SAVE_DIR, filename)
+    image.save(filepath)
+    return filepath
+
+

 def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
     """Add an outline stroke to the given text."""
 