Jose Benitez commited on
Commit
fdd33ad
·
1 Parent(s): 7a6ca6b

add hf lora models support

Browse files
Files changed (1) hide show
  1. services/image_generation.py +20 -9
services/image_generation.py CHANGED
@@ -2,6 +2,7 @@ import replicate
2
  from PIL import Image
3
  import requests
4
  from io import BytesIO
 
5
 
6
 
7
  def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default=False, trigger_word='hi'):
@@ -19,15 +20,7 @@ def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_sca
19
  }
20
  )
21
  else:
22
- # check if the model has a version. the model is something like model='user/model-name:version' but sometimes we just got model='user/model-name' in this case, let get and add the model version
23
- if ':' not in model_name:
24
- model_version = replicate.models.get(model_name).latest_version.id
25
- print(f"Model version: {model_version}")
26
- model_name = f"{model_name}:{model_version}"
27
-
28
- img_url = replicate.run(
29
- model_name,
30
- input={
31
  "model": "dev",
32
  "steps": steps,
33
  "prompt": prompt,
@@ -36,5 +29,23 @@ def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_sca
36
  "aspect_ratio": "1:1",
37
  "safety_tolerance": 2
38
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  )
40
  return img_url
 
 
2
  from PIL import Image
3
  import requests
4
  from io import BytesIO
5
+ from database import get_lora_models_info
6
 
7
 
8
  def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_scale, progress, use_default=False, trigger_word='hi'):
 
20
  }
21
  )
22
  else:
23
+ input = {
 
 
 
 
 
 
 
 
24
  "model": "dev",
25
  "steps": steps,
26
  "prompt": prompt,
 
29
  "aspect_ratio": "1:1",
30
  "safety_tolerance": 2
31
  }
32
+
33
+ db_loras = get_lora_models_info()
34
+
35
+ for lora in db_loras:
36
+ if lora["lora_name"] == model_name:
37
+ if lora["hf_repo"]:
38
+ input["hf_lora"] = lora["hf_repo"]
39
+ model_name = "lucataco/flux-dev-lora:a22c463f11808638ad5e2ebd582e07a469031f48dd567366fb4c6fdab91d614d"
40
+
41
+ if ':' not in model_name:
42
+ model_version = replicate.models.get(model_name).latest_version.id
43
+ print(f"Model version: {model_version}")
44
+ model_name = f"{model_name}:{model_version}"
45
+
46
+ img_url = replicate.run(
47
+ model_name,
48
+ input=input
49
  )
50
  return img_url
51
+