Spaces:
Sleeping
Sleeping
Jose Benitez
commited on
Commit
·
a1e077b
1
Parent(s):
afce03d
updates
Browse files- database.py +24 -2
- gradio_app.py +63 -23
- services/image_generation.py +27 -14
- services/train_lora.py +11 -5
database.py
CHANGED
@@ -35,11 +35,33 @@ def get_or_create_user(google_id, email, name, given_name, profile_picture):
|
|
35 |
return user.data[0]
|
36 |
|
37 |
def get_lora_models_info():
|
38 |
-
lora_models = supabase.table("lora_models").select("*").execute()
|
|
|
39 |
return lora_models.data
|
40 |
|
41 |
def get_user_by_id(user_id):
|
42 |
user = supabase.table("users").select("*").eq("id", user_id).execute()
|
43 |
if user.data:
|
44 |
return user.data[0]
|
45 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
return user.data[0]
|
36 |
|
37 |
def get_lora_models_info():
|
38 |
+
lora_models = supabase.table("lora_models").select("*").is_("user_id", None).execute()
|
39 |
+
|
40 |
return lora_models.data
|
41 |
|
42 |
def get_user_by_id(user_id):
|
43 |
user = supabase.table("users").select("*").eq("id", user_id).execute()
|
44 |
if user.data:
|
45 |
return user.data[0]
|
46 |
+
return None
|
47 |
+
|
48 |
+
def create_lora_models(user_id, replicate_repo_name, trigger_word, steps, lora_rank, batch_size, learning_rate, hf_repo_name, training_url):
|
49 |
+
# create a jsonb from trigger_word, train_steps, lora_rank, batch_size, learning_rate values
|
50 |
+
model_config = {
|
51 |
+
"train_steps": steps,
|
52 |
+
"lora_rank": lora_rank,
|
53 |
+
"batch_size": batch_size,
|
54 |
+
"learning_rate": learning_rate
|
55 |
+
}
|
56 |
+
result = supabase.table("lora_models").insert({
|
57 |
+
"user_id": user_id,
|
58 |
+
"trigger_word": trigger_word,
|
59 |
+
"lora_name": replicate_repo_name,
|
60 |
+
"hf_repo": hf_repo_name,
|
61 |
+
"configs": model_config,
|
62 |
+
"training_url": training_url
|
63 |
+
}).execute()
|
64 |
+
|
65 |
+
def get_user_lora_models(user_id):
|
66 |
+
user_models = supabase.table("lora_models").select("*").eq("user_id", user_id).execute()
|
67 |
+
return user_models.data
|
gradio_app.py
CHANGED
@@ -5,14 +5,13 @@ import json
|
|
5 |
import zipfile
|
6 |
from pathlib import Path
|
7 |
|
8 |
-
from database import get_user_credits, update_user_credits, get_lora_models_info
|
9 |
from services.image_generation import generate_image
|
10 |
from services.train_lora import lora_pipeline
|
11 |
from utils.image_utils import url_to_pil_image
|
12 |
|
13 |
lora_models = get_lora_models_info()
|
14 |
|
15 |
-
|
16 |
if not isinstance(lora_models, list):
|
17 |
raise ValueError("Expected loras_models to be a list of dictionaries.")
|
18 |
|
@@ -37,8 +36,20 @@ if main_header_path.is_file():
|
|
37 |
with main_header_path.open() as file:
|
38 |
main_header = file.read()
|
39 |
|
40 |
-
def
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
new_placeholder = f"Ingresa un prompt para tu modelo {selected_lora['lora_name']}"
|
43 |
trigger_word = selected_lora["trigger_word"]
|
44 |
updated_text = f"#### Palabra clave: {trigger_word} ✨"
|
@@ -48,13 +59,18 @@ def update_selection(evt: gr.SelectData, width, height):
|
|
48 |
width, height = 768, 1024
|
49 |
elif selected_lora["aspect"] == "landscape":
|
50 |
width, height = 1024, 768
|
51 |
-
|
52 |
-
return gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height
|
53 |
|
54 |
-
|
|
|
|
|
55 |
if not files:
|
56 |
-
return "No
|
|
|
|
|
|
|
|
|
57 |
|
|
|
58 |
# Create a directory in the user's home folder
|
59 |
output_dir = os.path.expanduser("~/gradio_training_data")
|
60 |
os.makedirs(output_dir, exist_ok=True)
|
@@ -72,7 +88,8 @@ def compress_and_train(files, model_name, trigger_word, train_steps, lora_rank,
|
|
72 |
|
73 |
print(f'[INFO] Procesando {trigger_word}')
|
74 |
# Now call the train_lora function with the zip file path
|
75 |
-
result = lora_pipeline(
|
|
|
76 |
model_name,
|
77 |
trigger_word=trigger_word,
|
78 |
steps=train_steps,
|
@@ -81,19 +98,39 @@ def compress_and_train(files, model_name, trigger_word, train_steps, lora_rank,
|
|
81 |
autocaption=True,
|
82 |
learning_rate=learning_rate)
|
83 |
|
84 |
-
return
|
85 |
|
86 |
-
def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index,
|
87 |
user = request.session.get('user')
|
88 |
if not user:
|
89 |
raise gr.Error("User not authenticated. Please log in.")
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
generation_credits, _ = get_user_credits(user['id'])
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
if generation_credits <= 0:
|
94 |
raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
|
95 |
|
96 |
-
image_url = generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress)
|
97 |
image = url_to_pil_image(image_url)
|
98 |
|
99 |
# Update user's credits
|
@@ -193,7 +230,7 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
|
|
193 |
columns=3,
|
194 |
elem_id="gallery"
|
195 |
)
|
196 |
-
|
197 |
|
198 |
with gr.Accordion("Configuracion Avanzada", open=False):
|
199 |
with gr.Row():
|
@@ -208,14 +245,19 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
|
|
208 |
|
209 |
gallery.select(
|
210 |
update_selection,
|
211 |
-
inputs=[width, height],
|
212 |
-
outputs=[prompt, selected_info, selected_index, width, height]
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
)
|
214 |
-
|
215 |
gr.on(
|
216 |
triggers=[generate_button.click, prompt.submit],
|
217 |
fn=run_lora,
|
218 |
-
inputs=[prompt, cfg_scale, steps, selected_index,
|
219 |
outputs=[result, generation_credits_display]
|
220 |
)
|
221 |
|
@@ -237,8 +279,6 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
|
|
237 |
learning_rate = gr.Number(label='learning_rate', value=0.0004)
|
238 |
training_status = gr.Textbox(label="Training Status")
|
239 |
|
240 |
-
|
241 |
-
|
242 |
train_button.click(
|
243 |
compress_and_train,
|
244 |
inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
|
@@ -249,14 +289,14 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
|
|
249 |
#main_demo.load(greet, None, title)
|
250 |
#main_demo.load(greet, None, greetings)
|
251 |
#main_demo.load((greet, display_credits), None, [greetings, generation_credits_display, train_credits_display])
|
|
|
252 |
main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display])
|
253 |
|
254 |
|
255 |
|
256 |
# TODO:
|
257 |
'''
|
258 |
-
-
|
259 |
-
- Galeria Modelos Open Source (accordion)
|
260 |
- Training con creditos.
|
261 |
- Stripe(?)
|
262 |
- Mejorar boton de login/logout
|
|
|
5 |
import zipfile
|
6 |
from pathlib import Path
|
7 |
|
8 |
+
from database import get_user_credits, update_user_credits, get_lora_models_info, get_user_lora_models
|
9 |
from services.image_generation import generate_image
|
10 |
from services.train_lora import lora_pipeline
|
11 |
from utils.image_utils import url_to_pil_image
|
12 |
|
13 |
lora_models = get_lora_models_info()
|
14 |
|
|
|
15 |
if not isinstance(lora_models, list):
|
16 |
raise ValueError("Expected loras_models to be a list of dictionaries.")
|
17 |
|
|
|
36 |
with main_header_path.open() as file:
|
37 |
main_header = file.read()
|
38 |
|
39 |
+
def load_user_models(request: gr.Request):
|
40 |
+
user = request.session.get('user')
|
41 |
+
print(user)
|
42 |
+
if user:
|
43 |
+
user_models = get_user_lora_models(user['id'])
|
44 |
+
if user_models:
|
45 |
+
return [(item.get("image_url", "assets/logo.jpg"), item["lora_name"]) for item in user_models]
|
46 |
+
return []
|
47 |
+
|
48 |
+
def update_selection(evt: gr.SelectData, gallery_type: str, width, height):
|
49 |
+
if gallery_type == "user":
|
50 |
+
selected_lora = {"lora_name": "custom", "trigger_word": "custom"}
|
51 |
+
else:
|
52 |
+
selected_lora = lora_models[evt.index]
|
53 |
new_placeholder = f"Ingresa un prompt para tu modelo {selected_lora['lora_name']}"
|
54 |
trigger_word = selected_lora["trigger_word"]
|
55 |
updated_text = f"#### Palabra clave: {trigger_word} ✨"
|
|
|
59 |
width, height = 768, 1024
|
60 |
elif selected_lora["aspect"] == "landscape":
|
61 |
width, height = 1024, 768
|
|
|
|
|
62 |
|
63 |
+
return gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height, gallery_type
|
64 |
+
|
65 |
+
def compress_and_train(request: gr.Request, files, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate):
|
66 |
if not files:
|
67 |
+
return "No hay imágenes. Sube algunas imágenes para poder entrenar."
|
68 |
+
|
69 |
+
user = request.session.get('user')
|
70 |
+
if not user:
|
71 |
+
raise gr.Error("User not authenticated. Please log in.")
|
72 |
|
73 |
+
user_id = user['id']
|
74 |
# Create a directory in the user's home folder
|
75 |
output_dir = os.path.expanduser("~/gradio_training_data")
|
76 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
88 |
|
89 |
print(f'[INFO] Procesando {trigger_word}')
|
90 |
# Now call the train_lora function with the zip file path
|
91 |
+
result = lora_pipeline(user_id,
|
92 |
+
zip_path,
|
93 |
model_name,
|
94 |
trigger_word=trigger_word,
|
95 |
steps=train_steps,
|
|
|
98 |
autocaption=True,
|
99 |
learning_rate=learning_rate)
|
100 |
|
101 |
+
return gr.Info("Tu modelo esta entrenando, En unos 20 minutos estará listo para que lo pruebes en 'Generación'.")
|
102 |
|
103 |
+
def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
|
104 |
user = request.session.get('user')
|
105 |
if not user:
|
106 |
raise gr.Error("User not authenticated. Please log in.")
|
107 |
+
lora_models = get_user_lora_models(user['id'])
|
108 |
+
print(f'Selected gallery: {selected_gallery}')
|
109 |
+
if selected_gallery == "user":
|
110 |
+
lora_models = get_user_lora_models(user['id'])
|
111 |
+
print('Using user models')
|
112 |
+
else: # public
|
113 |
+
lora_models = get_lora_models_info()
|
114 |
+
print('Using public models')
|
115 |
+
print(f'Selected index: {selected_index}')
|
116 |
+
if selected_index is None:
|
117 |
+
selected_lora = None
|
118 |
+
else:
|
119 |
+
selected_lora = lora_models[selected_index]
|
120 |
+
|
121 |
generation_credits, _ = get_user_credits(user['id'])
|
122 |
+
if selected_lora:
|
123 |
+
print(f"Selected Lora: {selected_lora['lora_name']}")
|
124 |
+
model_name = selected_lora['lora_name']
|
125 |
+
use_default = False
|
126 |
+
else:
|
127 |
+
model_name = "black-forest-labs/flux-pro"
|
128 |
+
print(f"Using default Lora: {model_name}")
|
129 |
+
use_default = True
|
130 |
if generation_credits <= 0:
|
131 |
raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
|
132 |
|
133 |
+
image_url = generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default)
|
134 |
image = url_to_pil_image(image_url)
|
135 |
|
136 |
# Update user's credits
|
|
|
230 |
columns=3,
|
231 |
elem_id="gallery"
|
232 |
)
|
233 |
+
gallery_type = gr.State("Public")
|
234 |
|
235 |
with gr.Accordion("Configuracion Avanzada", open=False):
|
236 |
with gr.Row():
|
|
|
245 |
|
246 |
gallery.select(
|
247 |
update_selection,
|
248 |
+
inputs=[gr.State("public"), width, height],
|
249 |
+
outputs=[prompt, selected_info, selected_index, width, height, gallery_type]
|
250 |
+
)
|
251 |
+
|
252 |
+
user_model_gallery.select(
|
253 |
+
update_selection,
|
254 |
+
inputs=[gr.State("user"), width, height],
|
255 |
+
outputs=[prompt, selected_info, selected_index, width, height, gallery_type]
|
256 |
)
|
|
|
257 |
gr.on(
|
258 |
triggers=[generate_button.click, prompt.submit],
|
259 |
fn=run_lora,
|
260 |
+
inputs=[prompt, cfg_scale, steps, selected_index, gallery_type, width, height, lora_scale],
|
261 |
outputs=[result, generation_credits_display]
|
262 |
)
|
263 |
|
|
|
279 |
learning_rate = gr.Number(label='learning_rate', value=0.0004)
|
280 |
training_status = gr.Textbox(label="Training Status")
|
281 |
|
|
|
|
|
282 |
train_button.click(
|
283 |
compress_and_train,
|
284 |
inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
|
|
|
289 |
#main_demo.load(greet, None, title)
|
290 |
#main_demo.load(greet, None, greetings)
|
291 |
#main_demo.load((greet, display_credits), None, [greetings, generation_credits_display, train_credits_display])
|
292 |
+
main_demo.load(load_user_models, None, user_model_gallery)
|
293 |
main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display])
|
294 |
|
295 |
|
296 |
|
297 |
# TODO:
|
298 |
'''
|
299 |
+
- resolver mostrar bien los nombres de los modelos en la galeria
|
|
|
300 |
- Training con creditos.
|
301 |
- Stripe(?)
|
302 |
- Mejorar boton de login/logout
|
services/image_generation.py
CHANGED
@@ -3,19 +3,32 @@ from PIL import Image
|
|
3 |
import requests
|
4 |
from io import BytesIO
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
def generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress, trigger_word='hi'):
|
9 |
print(f"Generating image for prompt: {prompt}")
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
return img_url
|
|
|
3 |
import requests
|
4 |
from io import BytesIO
|
5 |
|
6 |
+
|
7 |
+
def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default=False, trigger_word='hi'):
|
|
|
8 |
print(f"Generating image for prompt: {prompt}")
|
9 |
+
if use_default:
|
10 |
+
img_url = replicate.run(
|
11 |
+
"black-forest-labs/flux-pro",
|
12 |
+
input={
|
13 |
+
"steps": steps,
|
14 |
+
"prompt": prompt,
|
15 |
+
"guidance": cfg_scale,
|
16 |
+
"interval": 2,
|
17 |
+
"aspect_ratio": "1:1",
|
18 |
+
"safety_tolerance": 2
|
19 |
+
}
|
20 |
+
)
|
21 |
+
else:
|
22 |
+
img_url = replicate.run(
|
23 |
+
model_name,
|
24 |
+
input={
|
25 |
+
"model": "dev",
|
26 |
+
"steps": steps,
|
27 |
+
"prompt": prompt,
|
28 |
+
"guidance": cfg_scale,
|
29 |
+
"interval": 2,
|
30 |
+
"aspect_ratio": "1:1",
|
31 |
+
"safety_tolerance": 2
|
32 |
+
}
|
33 |
+
)
|
34 |
return img_url
|
services/train_lora.py
CHANGED
@@ -1,20 +1,22 @@
|
|
1 |
import replicate
|
2 |
import os
|
3 |
from huggingface_hub import create_repo
|
|
|
4 |
|
5 |
REPLICATE_OWNER = "josebenitezg"
|
6 |
|
7 |
-
def lora_pipeline(zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
|
8 |
print(f'Creating dataset for {model_name}')
|
9 |
-
|
10 |
-
|
|
|
11 |
|
12 |
lora_name = f"flux-dev-{model_name}"
|
13 |
|
14 |
model = replicate.models.create(
|
15 |
owner=REPLICATE_OWNER,
|
16 |
name=lora_name,
|
17 |
-
visibility="
|
18 |
hardware="gpu-t4", # Replicate will override this for fine-tuned models
|
19 |
description="A fine-tuned FLUX.1 model"
|
20 |
)
|
@@ -37,10 +39,14 @@ def lora_pipeline(zip_path, model_name, trigger_word="TOK", steps=1000, lora_ran
|
|
37 |
"trigger_word": trigger_word,
|
38 |
"learning_rate": learning_rate,
|
39 |
"hf_token": os.getenv('HF_TOKEN'), # optional
|
40 |
-
"hf_repo_id":
|
41 |
},
|
42 |
destination=f"{model.owner}/{model.name}"
|
43 |
)
|
44 |
|
|
|
45 |
print(f"Training started: {training.status}")
|
46 |
print(f"Training URL: https://replicate.com/p/{training.id}")
|
|
|
|
|
|
|
|
1 |
import replicate
|
2 |
import os
|
3 |
from huggingface_hub import create_repo
|
4 |
+
from database import create_lora_models
|
5 |
|
6 |
REPLICATE_OWNER = "josebenitezg"
|
7 |
|
8 |
+
def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
|
9 |
print(f'Creating dataset for {model_name}')
|
10 |
+
hf_repo_name = f"joselobenitezg/flux-dev-{model_name}"
|
11 |
+
replicate_repo_name = f"josebenitezg/flux-dev-{model_name}"
|
12 |
+
create_repo(hf_repo_name, repo_type='model')
|
13 |
|
14 |
lora_name = f"flux-dev-{model_name}"
|
15 |
|
16 |
model = replicate.models.create(
|
17 |
owner=REPLICATE_OWNER,
|
18 |
name=lora_name,
|
19 |
+
visibility="private", # or "private" if you prefer
|
20 |
hardware="gpu-t4", # Replicate will override this for fine-tuned models
|
21 |
description="A fine-tuned FLUX.1 model"
|
22 |
)
|
|
|
39 |
"trigger_word": trigger_word,
|
40 |
"learning_rate": learning_rate,
|
41 |
"hf_token": os.getenv('HF_TOKEN'), # optional
|
42 |
+
"hf_repo_id": hf_repo_name, # optional
|
43 |
},
|
44 |
destination=f"{model.owner}/{model.name}"
|
45 |
)
|
46 |
|
47 |
+
print(f"training: {training.keys()}")
|
48 |
print(f"Training started: {training.status}")
|
49 |
print(f"Training URL: https://replicate.com/p/{training.id}")
|
50 |
+
print(f"Creating model in Database")
|
51 |
+
training_url = f"https://replicate.com/p/{training.id}"
|
52 |
+
create_lora_models(user_id, replicate_repo_name, trigger_word, steps, lora_rank, batch_size, learning_rate, hf_repo_name, training_url)
|