Jose Benitez commited on
Commit
a1e077b
·
1 Parent(s): afce03d
Files changed (4) hide show
  1. database.py +24 -2
  2. gradio_app.py +63 -23
  3. services/image_generation.py +27 -14
  4. services/train_lora.py +11 -5
database.py CHANGED
@@ -35,11 +35,33 @@ def get_or_create_user(google_id, email, name, given_name, profile_picture):
35
  return user.data[0]
36
 
37
  def get_lora_models_info():
38
- lora_models = supabase.table("lora_models").select("*").execute()
 
39
  return lora_models.data
40
 
41
  def get_user_by_id(user_id):
42
  user = supabase.table("users").select("*").eq("id", user_id).execute()
43
  if user.data:
44
  return user.data[0]
45
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  return user.data[0]
36
 
37
  def get_lora_models_info():
38
+ lora_models = supabase.table("lora_models").select("*").is_("user_id", None).execute()
39
+
40
  return lora_models.data
41
 
42
  def get_user_by_id(user_id):
43
  user = supabase.table("users").select("*").eq("id", user_id).execute()
44
  if user.data:
45
  return user.data[0]
46
+ return None
47
+
48
+ def create_lora_models(user_id, replicate_repo_name, trigger_word, steps, lora_rank, batch_size, learning_rate, hf_repo_name, training_url):
49
+ # create a jsonb from trigger_word, train_steps, lora_rank, batch_size, learning_rate values
50
+ model_config = {
51
+ "train_steps": steps,
52
+ "lora_rank": lora_rank,
53
+ "batch_size": batch_size,
54
+ "learning_rate": learning_rate
55
+ }
56
+ result = supabase.table("lora_models").insert({
57
+ "user_id": user_id,
58
+ "trigger_word": trigger_word,
59
+ "lora_name": replicate_repo_name,
60
+ "hf_repo": hf_repo_name,
61
+ "configs": model_config,
62
+ "training_url": training_url
63
+ }).execute()
64
+
65
+ def get_user_lora_models(user_id):
66
+ user_models = supabase.table("lora_models").select("*").eq("user_id", user_id).execute()
67
+ return user_models.data
gradio_app.py CHANGED
@@ -5,14 +5,13 @@ import json
5
  import zipfile
6
  from pathlib import Path
7
 
8
- from database import get_user_credits, update_user_credits, get_lora_models_info
9
  from services.image_generation import generate_image
10
  from services.train_lora import lora_pipeline
11
  from utils.image_utils import url_to_pil_image
12
 
13
  lora_models = get_lora_models_info()
14
 
15
-
16
  if not isinstance(lora_models, list):
17
  raise ValueError("Expected loras_models to be a list of dictionaries.")
18
 
@@ -37,8 +36,20 @@ if main_header_path.is_file():
37
  with main_header_path.open() as file:
38
  main_header = file.read()
39
 
40
- def update_selection(evt: gr.SelectData, width, height):
41
- selected_lora = lora_models[evt.index]
 
 
 
 
 
 
 
 
 
 
 
 
42
  new_placeholder = f"Ingresa un prompt para tu modelo {selected_lora['lora_name']}"
43
  trigger_word = selected_lora["trigger_word"]
44
  updated_text = f"#### Palabra clave: {trigger_word} ✨"
@@ -48,13 +59,18 @@ def update_selection(evt: gr.SelectData, width, height):
48
  width, height = 768, 1024
49
  elif selected_lora["aspect"] == "landscape":
50
  width, height = 1024, 768
51
-
52
- return gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height
53
 
54
- def compress_and_train(files, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate):
 
 
55
  if not files:
56
- return "No images uploaded. Please upload images before training."
 
 
 
 
57
 
 
58
  # Create a directory in the user's home folder
59
  output_dir = os.path.expanduser("~/gradio_training_data")
60
  os.makedirs(output_dir, exist_ok=True)
@@ -72,7 +88,8 @@ def compress_and_train(files, model_name, trigger_word, train_steps, lora_rank,
72
 
73
  print(f'[INFO] Procesando {trigger_word}')
74
  # Now call the train_lora function with the zip file path
75
- result = lora_pipeline(zip_path,
 
76
  model_name,
77
  trigger_word=trigger_word,
78
  steps=train_steps,
@@ -81,19 +98,39 @@ def compress_and_train(files, model_name, trigger_word, train_steps, lora_rank,
81
  autocaption=True,
82
  learning_rate=learning_rate)
83
 
84
- return f"{result}\n\nZip file saved at: {zip_path}"
85
 
86
- def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, randomize_seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
87
  user = request.session.get('user')
88
  if not user:
89
  raise gr.Error("User not authenticated. Please log in.")
90
-
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  generation_credits, _ = get_user_credits(user['id'])
92
-
 
 
 
 
 
 
 
93
  if generation_credits <= 0:
94
  raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
95
 
96
- image_url = generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress)
97
  image = url_to_pil_image(image_url)
98
 
99
  # Update user's credits
@@ -193,7 +230,7 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
193
  columns=3,
194
  elem_id="gallery"
195
  )
196
-
197
 
198
  with gr.Accordion("Configuracion Avanzada", open=False):
199
  with gr.Row():
@@ -208,14 +245,19 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
208
 
209
  gallery.select(
210
  update_selection,
211
- inputs=[width, height],
212
- outputs=[prompt, selected_info, selected_index, width, height]
 
 
 
 
 
 
213
  )
214
-
215
  gr.on(
216
  triggers=[generate_button.click, prompt.submit],
217
  fn=run_lora,
218
- inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, width, height, lora_scale],
219
  outputs=[result, generation_credits_display]
220
  )
221
 
@@ -237,8 +279,6 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
237
  learning_rate = gr.Number(label='learning_rate', value=0.0004)
238
  training_status = gr.Textbox(label="Training Status")
239
 
240
-
241
-
242
  train_button.click(
243
  compress_and_train,
244
  inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
@@ -249,14 +289,14 @@ with gr.Blocks(theme=gr.themes.Soft(), head=header, css=main_css) as main_demo:
249
  #main_demo.load(greet, None, title)
250
  #main_demo.load(greet, None, greetings)
251
  #main_demo.load((greet, display_credits), None, [greetings, generation_credits_display, train_credits_display])
 
252
  main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display])
253
 
254
 
255
 
256
  # TODO:
257
  '''
258
- - Galeria Modelos Propios (si existe alguno del user, si no, mostrar un mensaje para entrenar)
259
- - Galeria Modelos Open Source (accordion)
260
  - Training con creditos.
261
  - Stripe(?)
262
  - Mejorar boton de login/logout
 
5
  import zipfile
6
  from pathlib import Path
7
 
8
+ from database import get_user_credits, update_user_credits, get_lora_models_info, get_user_lora_models
9
  from services.image_generation import generate_image
10
  from services.train_lora import lora_pipeline
11
  from utils.image_utils import url_to_pil_image
12
 
13
  lora_models = get_lora_models_info()
14
 
 
15
  if not isinstance(lora_models, list):
16
  raise ValueError("Expected loras_models to be a list of dictionaries.")
17
 
 
36
  with main_header_path.open() as file:
37
  main_header = file.read()
38
 
39
+ def load_user_models(request: gr.Request):
40
+ user = request.session.get('user')
41
+ print(user)
42
+ if user:
43
+ user_models = get_user_lora_models(user['id'])
44
+ if user_models:
45
+ return [(item.get("image_url", "assets/logo.jpg"), item["lora_name"]) for item in user_models]
46
+ return []
47
+
48
+ def update_selection(evt: gr.SelectData, gallery_type: str, width, height):
49
+ if gallery_type == "user":
50
+ selected_lora = {"lora_name": "custom", "trigger_word": "custom"}
51
+ else:
52
+ selected_lora = lora_models[evt.index]
53
  new_placeholder = f"Ingresa un prompt para tu modelo {selected_lora['lora_name']}"
54
  trigger_word = selected_lora["trigger_word"]
55
  updated_text = f"#### Palabra clave: {trigger_word} ✨"
 
59
  width, height = 768, 1024
60
  elif selected_lora["aspect"] == "landscape":
61
  width, height = 1024, 768
 
 
62
 
63
+ return gr.update(placeholder=new_placeholder), updated_text, evt.index, width, height, gallery_type
64
+
65
+ def compress_and_train(request: gr.Request, files, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate):
66
  if not files:
67
+ return "No hay imágenes. Sube algunas imágenes para poder entrenar."
68
+
69
+ user = request.session.get('user')
70
+ if not user:
71
+ raise gr.Error("User not authenticated. Please log in.")
72
 
73
+ user_id = user['id']
74
  # Create a directory in the user's home folder
75
  output_dir = os.path.expanduser("~/gradio_training_data")
76
  os.makedirs(output_dir, exist_ok=True)
 
88
 
89
  print(f'[INFO] Procesando {trigger_word}')
90
  # Now call the train_lora function with the zip file path
91
+ result = lora_pipeline(user_id,
92
+ zip_path,
93
  model_name,
94
  trigger_word=trigger_word,
95
  steps=train_steps,
 
98
  autocaption=True,
99
  learning_rate=learning_rate)
100
 
101
+ return gr.Info("Tu modelo esta entrenando, En unos 20 minutos estará listo para que lo pruebes en 'Generación'.")
102
 
103
+ def run_lora(request: gr.Request, prompt, cfg_scale, steps, selected_index, selected_gallery, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
104
  user = request.session.get('user')
105
  if not user:
106
  raise gr.Error("User not authenticated. Please log in.")
107
+ lora_models = get_user_lora_models(user['id'])
108
+ print(f'Selected gallery: {selected_gallery}')
109
+ if selected_gallery == "user":
110
+ lora_models = get_user_lora_models(user['id'])
111
+ print('Using user models')
112
+ else: # public
113
+ lora_models = get_lora_models_info()
114
+ print('Using public models')
115
+ print(f'Selected index: {selected_index}')
116
+ if selected_index is None:
117
+ selected_lora = None
118
+ else:
119
+ selected_lora = lora_models[selected_index]
120
+
121
  generation_credits, _ = get_user_credits(user['id'])
122
+ if selected_lora:
123
+ print(f"Selected Lora: {selected_lora['lora_name']}")
124
+ model_name = selected_lora['lora_name']
125
+ use_default = False
126
+ else:
127
+ model_name = "black-forest-labs/flux-pro"
128
+ print(f"Using default Lora: {model_name}")
129
+ use_default = True
130
  if generation_credits <= 0:
131
  raise gr.Error("Ya no tienes creditos disponibles. Compra para continuar.")
132
 
133
+ image_url = generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default)
134
  image = url_to_pil_image(image_url)
135
 
136
  # Update user's credits
 
230
  columns=3,
231
  elem_id="gallery"
232
  )
233
+ gallery_type = gr.State("Public")
234
 
235
  with gr.Accordion("Configuracion Avanzada", open=False):
236
  with gr.Row():
 
245
 
246
  gallery.select(
247
  update_selection,
248
+ inputs=[gr.State("public"), width, height],
249
+ outputs=[prompt, selected_info, selected_index, width, height, gallery_type]
250
+ )
251
+
252
+ user_model_gallery.select(
253
+ update_selection,
254
+ inputs=[gr.State("user"), width, height],
255
+ outputs=[prompt, selected_info, selected_index, width, height, gallery_type]
256
  )
 
257
  gr.on(
258
  triggers=[generate_button.click, prompt.submit],
259
  fn=run_lora,
260
+ inputs=[prompt, cfg_scale, steps, selected_index, gallery_type, width, height, lora_scale],
261
  outputs=[result, generation_credits_display]
262
  )
263
 
 
279
  learning_rate = gr.Number(label='learning_rate', value=0.0004)
280
  training_status = gr.Textbox(label="Training Status")
281
 
 
 
282
  train_button.click(
283
  compress_and_train,
284
  inputs=[train_dataset, model_name, trigger_word, train_steps, lora_rank, batch_size, learning_rate],
 
289
  #main_demo.load(greet, None, title)
290
  #main_demo.load(greet, None, greetings)
291
  #main_demo.load((greet, display_credits), None, [greetings, generation_credits_display, train_credits_display])
292
+ main_demo.load(load_user_models, None, user_model_gallery)
293
  main_demo.load(load_greet_and_credits, None, [greetings, generation_credits_display, train_credits_display])
294
 
295
 
296
 
297
  # TODO:
298
  '''
299
+ - resolver mostrar bien los nombres de los modelos en la galeria
 
300
  - Training con creditos.
301
  - Stripe(?)
302
  - Mejorar boton de login/logout
services/image_generation.py CHANGED
@@ -3,19 +3,32 @@ from PIL import Image
3
  import requests
4
  from io import BytesIO
5
 
6
- #model_custom_test = "josebenitezg/flux-dev-ruth-estilo-1:c7ff81b58007c7cee3f69416e1e999192dafd8d1b1f269ea6cae137f04b34172"
7
- flux_pro = "black-forest-labs/flux-pro"
8
- def generate_image(prompt, steps, cfg_scale, width, height, lora_scale, progress, trigger_word='hi'):
9
  print(f"Generating image for prompt: {prompt}")
10
- img_url = replicate.run(
11
- flux_pro,
12
- input={
13
- "steps": steps,
14
- "prompt": prompt,
15
- "guidance": cfg_scale,
16
- "interval": 2,
17
- "aspect_ratio": "1:1",
18
- "safety_tolerance": 2
19
- }
20
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  return img_url
 
3
  import requests
4
  from io import BytesIO
5
 
6
+
7
+ def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default=False, trigger_word='hi'):
 
8
  print(f"Generating image for prompt: {prompt}")
9
+ if use_default:
10
+ img_url = replicate.run(
11
+ "black-forest-labs/flux-pro",
12
+ input={
13
+ "steps": steps,
14
+ "prompt": prompt,
15
+ "guidance": cfg_scale,
16
+ "interval": 2,
17
+ "aspect_ratio": "1:1",
18
+ "safety_tolerance": 2
19
+ }
20
+ )
21
+ else:
22
+ img_url = replicate.run(
23
+ model_name,
24
+ input={
25
+ "model": "dev",
26
+ "steps": steps,
27
+ "prompt": prompt,
28
+ "guidance": cfg_scale,
29
+ "interval": 2,
30
+ "aspect_ratio": "1:1",
31
+ "safety_tolerance": 2
32
+ }
33
+ )
34
  return img_url
services/train_lora.py CHANGED
@@ -1,20 +1,22 @@
1
  import replicate
2
  import os
3
  from huggingface_hub import create_repo
 
4
 
5
  REPLICATE_OWNER = "josebenitezg"
6
 
7
- def lora_pipeline(zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
8
  print(f'Creating dataset for {model_name}')
9
- repo_name = f"joselobenitezg/flux-dev-{model_name}"
10
- create_repo(repo_name, repo_type='model')
 
11
 
12
  lora_name = f"flux-dev-{model_name}"
13
 
14
  model = replicate.models.create(
15
  owner=REPLICATE_OWNER,
16
  name=lora_name,
17
- visibility="public", # or "private" if you prefer
18
  hardware="gpu-t4", # Replicate will override this for fine-tuned models
19
  description="A fine-tuned FLUX.1 model"
20
  )
@@ -37,10 +39,14 @@ def lora_pipeline(zip_path, model_name, trigger_word="TOK", steps=1000, lora_ran
37
  "trigger_word": trigger_word,
38
  "learning_rate": learning_rate,
39
  "hf_token": os.getenv('HF_TOKEN'), # optional
40
- "hf_repo_id": repo_name, # optional
41
  },
42
  destination=f"{model.owner}/{model.name}"
43
  )
44
 
 
45
  print(f"Training started: {training.status}")
46
  print(f"Training URL: https://replicate.com/p/{training.id}")
 
 
 
 
1
  import replicate
2
  import os
3
  from huggingface_hub import create_repo
4
+ from database import create_lora_models
5
 
6
  REPLICATE_OWNER = "josebenitezg"
7
 
8
+ def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
9
  print(f'Creating dataset for {model_name}')
10
+ hf_repo_name = f"joselobenitezg/flux-dev-{model_name}"
11
+ replicate_repo_name = f"josebenitezg/flux-dev-{model_name}"
12
+ create_repo(hf_repo_name, repo_type='model')
13
 
14
  lora_name = f"flux-dev-{model_name}"
15
 
16
  model = replicate.models.create(
17
  owner=REPLICATE_OWNER,
18
  name=lora_name,
19
+ visibility="private", # or "private" if you prefer
20
  hardware="gpu-t4", # Replicate will override this for fine-tuned models
21
  description="A fine-tuned FLUX.1 model"
22
  )
 
39
  "trigger_word": trigger_word,
40
  "learning_rate": learning_rate,
41
  "hf_token": os.getenv('HF_TOKEN'), # optional
42
+ "hf_repo_id": hf_repo_name, # optional
43
  },
44
  destination=f"{model.owner}/{model.name}"
45
  )
46
 
47
+ print(f"training: {training.keys()}")
48
  print(f"Training started: {training.status}")
49
  print(f"Training URL: https://replicate.com/p/{training.id}")
50
+ print(f"Creating model in Database")
51
+ training_url = f"https://replicate.com/p/{training.id}"
52
+ create_lora_models(user_id, replicate_repo_name, trigger_word, steps, lora_rank, batch_size, learning_rate, hf_repo_name, training_url)