yasserrmd commited on
Commit
9388d07
·
verified ·
1 Parent(s): 19f0e18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -9,8 +9,17 @@ device=torch.device('cuda')
9
 
10
  # Load the model and LoRA weights
11
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
12
- #pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch", weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
13
- pipe.load_lora_weights("Datou1111/shou_xin", weight_name="shou_xin.safetensors")
 
 
 
 
 
 
 
 
 
14
  pipe.fuse_lora(lora_scale=1.5)
15
  pipe.to("cuda")
16
 
@@ -22,7 +31,7 @@ NSFW_THRESHOLD = 0.3
22
 
23
  # Define the function to generate the sketch
24
  @spaces.GPU
25
- def generate_sketch(prompt, num_inference_steps, guidance_scale):
26
  # Classify the text for NSFW content
27
  #text_classification = text_classifier(prompt)
28
  #print(text_classification)
@@ -31,9 +40,8 @@ def generate_sketch(prompt, num_inference_steps, guidance_scale):
31
  #for result in text_classification:
32
  # if result['label'] == 'nsfw' and result['score'] > NSFW_THRESHOLD:
33
  # return gr.update(visible=False),gr.Text(value="Inappropriate prompt detected. Please try another prompt.")
34
- print(prompt)
35
- prompt= "sketched style, " + prompt
36
- prompt= "shou_xin, " + prompt
37
  image = pipe("sketched style, " + prompt,
38
  num_inference_steps=num_inference_steps,
39
  guidance_scale=guidance_scale,
@@ -60,6 +68,9 @@ interface = gr.Interface(
60
  fn=generate_sketch,
61
  inputs=[
62
  "text", # Prompt input
 
 
 
63
  gr.Slider(5, 50, value=24, step=1, label="Number of Inference Steps"), # Slider for num_inference_steps
64
  gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="Guidance Scale") # Slider for guidance_scale
65
  ],
 
9
 
10
  # Load the model and LoRA weights
11
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
12
+ print(prompt)
13
+
14
+ if style='shou_xin':
15
+ prompt= "shou_xin, " + prompt
16
+ pipe.load_lora_weights("Datou1111/shou_xin", weight_name="shou_xin.safetensors")
17
+ else:
18
+ prompt= "sketched style, " + prompt
19
+ pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch", weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")
20
+
21
+
22
+
23
  pipe.fuse_lora(lora_scale=1.5)
24
  pipe.to("cuda")
25
 
 
31
 
32
  # Define the function to generate the sketch
33
  @spaces.GPU
34
+ def generate_sketch(prompt,style, num_inference_steps, guidance_scale):
35
  # Classify the text for NSFW content
36
  #text_classification = text_classifier(prompt)
37
  #print(text_classification)
 
40
  #for result in text_classification:
41
  # if result['label'] == 'nsfw' and result['score'] > NSFW_THRESHOLD:
42
  # return gr.update(visible=False),gr.Text(value="Inappropriate prompt detected. Please try another prompt.")
43
+
44
+
 
45
  image = pipe("sketched style, " + prompt,
46
  num_inference_steps=num_inference_steps,
47
  guidance_scale=guidance_scale,
 
68
  fn=generate_sketch,
69
  inputs=[
70
  "text", # Prompt input
71
+ gr.Dropdown(
72
+ ["sketched", "shou_xin"], label="Style"
73
+ ),
74
  gr.Slider(5, 50, value=24, step=1, label="Number of Inference Steps"), # Slider for num_inference_steps
75
  gr.Slider(1.0, 10.0, value=3.5, step=0.1, label="Guidance Scale") # Slider for guidance_scale
76
  ],