luoxiang commited on
Commit
b88a2d3
·
1 Parent(s): 490592b
Files changed (2) hide show
  1. app.py +37 -36
  2. requirements.txt +7 -0
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
- #import torch
3
  #from torch import autocast
4
- #from diffusers import StableDiffusionPipeline
5
  from datasets import load_dataset
6
  from PIL import Image
7
  #from io import BytesIO
@@ -10,19 +10,20 @@ import re
10
  import os
11
  import requests
12
 
 
13
  from share_btn import community_icon_html, loading_icon_html, share_js
14
 
15
  model_id = "CompVis/stable-diffusion-v1-4"
16
- device = "cuda"
17
 
18
  #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
19
- #pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
20
  #pipe = pipe.to(device)
21
- #torch.backends.cudnn.benchmark = True
22
 
23
- #When running locally, you won`t have access to this, so you can remove this part
24
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
25
- word_list = word_list_dataset["train"]['text']
26
 
27
  is_gpu_busy = False
28
  def infer(prompt):
@@ -30,38 +31,38 @@ def infer(prompt):
30
  samples = 4
31
  steps = 50
32
  scale = 7.5
33
- #When running locally you can also remove this filter
34
- for filter in word_list:
35
- if re.search(rf"\b{filter}\b", prompt):
36
- raise gr.Error("Unsafe content found. Please try again with different prompts.")
37
 
38
  #generator = torch.Generator(device=device).manual_seed(seed)
39
  #print("Is GPU busy? ", is_gpu_busy)
40
  images = []
41
- #if(not is_gpu_busy):
42
- # is_gpu_busy = True
43
- # images_list = pipe(
44
- # [prompt] * samples,
45
- # num_inference_steps=steps,
46
- # guidance_scale=scale,
47
  #generator=generator,
48
- # )
49
- # is_gpu_busy = False
50
- # safe_image = Image.open(r"unsafe.png")
51
- # for i, image in enumerate(images_list["sample"]):
52
  # if(images_list["nsfw_content_detected"][i]):
53
  # images.append(safe_image)
54
  # else:
55
- # images.append(image)
56
  #else:
57
- url = os.getenv('JAX_BACKEND_URL')
58
- payload = {'prompt': prompt}
59
- images_request = requests.post(url, json = payload)
60
- for image in images_request.json()["images"]:
61
- image_b64 = (f"data:image/png;base64,{image}")
62
- images.append(image_b64)
63
 
64
- return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
65
 
66
 
67
  css = """
@@ -312,9 +313,9 @@ with block:
312
  with gr.Group(elem_id="container-advanced-btns"):
313
  advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
314
  with gr.Group(elem_id="share-btn-container"):
315
- community_icon = gr.HTML(community_icon_html, visible=False)
316
- loading_icon = gr.HTML(loading_icon_html, visible=False)
317
- share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
318
 
319
  with gr.Row(elem_id="advanced-options"):
320
  gr.Markdown("Advanced settings are temporarily unavailable")
@@ -334,8 +335,8 @@ with block:
334
  ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=False)
335
  ex.dataset.headers = [""]
336
 
337
- text.submit(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], postprocess=False)
338
- btn.click(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], postprocess=False)
339
 
340
  advanced_button.click(
341
  None,
@@ -368,4 +369,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
368
  """
369
  )
370
 
371
- block.queue(max_size=5, concurrency_count=2).launch()
 
1
  import gradio as gr
2
+ import torch
3
  #from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
  from datasets import load_dataset
6
  from PIL import Image
7
  #from io import BytesIO
 
10
  import os
11
  import requests
12
 
13
+
14
  from share_btn import community_icon_html, loading_icon_html, share_js
15
 
16
  model_id = "CompVis/stable-diffusion-v1-4"
17
+ #device = "cuda"
18
 
19
  #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
20
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=os.getenv('ACCESS_TOKEN'), torch_dtype=torch.float32)
21
  #pipe = pipe.to(device)
22
+ torch.backends.cudnn.benchmark = True
23
 
24
+ ##When running locally, you won`t have access to this, so you can remove this part
25
+ #word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
26
+ #word_list = word_list_dataset["train"]['text']
27
 
28
  is_gpu_busy = False
29
  def infer(prompt):
 
31
  samples = 4
32
  steps = 50
33
  scale = 7.5
34
+ ##When running locally you can also remove this filter
35
+ #for filter in word_list:
36
+ # if re.search(rf"\b{filter}\b", prompt):
37
+ # raise gr.Error("Unsafe content found. Please try again with different prompts.")
38
 
39
  #generator = torch.Generator(device=device).manual_seed(seed)
40
  #print("Is GPU busy? ", is_gpu_busy)
41
  images = []
42
+ if(not is_gpu_busy):
43
+ is_gpu_busy = True
44
+ images_list = pipe(
45
+ [prompt] * samples,
46
+ num_inference_steps=steps,
47
+ guidance_scale=scale,
48
  #generator=generator,
49
+ )
50
+ is_gpu_busy = False
51
+ #safe_image = Image.open(r"unsafe.png")
52
+ for i, image in enumerate(images_list["sample"]):
53
  # if(images_list["nsfw_content_detected"][i]):
54
  # images.append(safe_image)
55
  # else:
56
+ images.append(image)
57
  #else:
58
+ #url = os.getenv('JAX_BACKEND_URL')
59
+ #payload = {'prompt': prompt}
60
+ #images_request = requests.post(url, json = payload)
61
+ #for image in images_request.json()["images"]:
62
+ # image_b64 = (f"data:image/jpeg;base64,{image}")
63
+ # images.append(image_b64)
64
 
65
+ return images
66
 
67
 
68
  css = """
 
313
  with gr.Group(elem_id="container-advanced-btns"):
314
  advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
315
  with gr.Group(elem_id="share-btn-container"):
316
+ community_icon = gr.HTML(community_icon_html)
317
+ loading_icon = gr.HTML(loading_icon_html)
318
+ share_button = gr.Button("Share to community", elem_id="share-btn")
319
 
320
  with gr.Row(elem_id="advanced-options"):
321
  gr.Markdown("Advanced settings are temporarily unavailable")
 
335
  ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=False)
336
  ex.dataset.headers = [""]
337
 
338
+ text.submit(infer, inputs=text, outputs=[gallery], postprocess=False)
339
+ btn.click(infer, inputs=text, outputs=[gallery], postprocess=False)
340
 
341
  advanced_button.click(
342
  None,
 
369
  """
370
  )
371
 
372
+ block.queue(max_size=50, concurrency_count=20).launch()
requirements.txt CHANGED
@@ -1,2 +1,9 @@
1
  python-dotenv
 
 
 
 
 
 
 
2
  https://gradio-builds.s3.amazonaws.com/queue-disconnect/v3/gradio-3.4b2-py3-none-any.whl
 
1
  python-dotenv
2
+ mkl
3
+ spacy
4
+ ftfy
5
+ torch
6
+ transformers
7
+ diffusers
8
+ torchvision
9
  https://gradio-builds.s3.amazonaws.com/queue-disconnect/v3/gradio-3.4b2-py3-none-any.whl