Spaces:
Running on CPU Upgrade

zijun/sdxl-gke-demo

#1546
by zzjzz - opened
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -5,6 +5,7 @@ from datasets import load_dataset
5
  import base64
6
  import re
7
  import os
 
8
  import requests
9
  import time
10
  from PIL import Image
@@ -88,10 +89,11 @@ def infer(prompt, negative="low_quality", scale=7, style_name=None, profile: gr.
88
  if re.search(rf"\b{filter}\b", prompt):
89
  raise gr.Error("Please try again with a different prompt")
90
 
 
91
  prompt, negative = apply_style(style_name, prompt, negative)
92
  images = []
93
  url = os.getenv('JAX_BACKEND_URL')
94
- payload = {'prompt': prompt, 'negative_prompt': negative, 'guidance_scale': scale}
95
  start_time = time.time()
96
  images_request = requests.post(url, json = payload)
97
  print(time.time() - start_time)
@@ -100,22 +102,23 @@ def infer(prompt, negative="low_quality", scale=7, style_name=None, profile: gr.
100
  except requests.exceptions.JSONDecodeError:
101
  raise gr.Error("SDXL did not return a valid result, try again")
102
 
103
- for image in json_data["images"]:
104
- image_b64 = (f"data:image/jpeg;base64,{image}")
105
- images.append(image_b64)
 
106
 
107
- if profile is not None: # avoid conversion on non-logged-in users
108
- pil_image = Image.open(BytesIO(base64.b64decode(image)))
109
- user_history.save_image( # save images + metadata to user history
110
- label=prompt,
111
- image=pil_image,
112
- profile=profile,
113
- metadata={
114
- "prompt": prompt,
115
- "negative_prompt": negative,
116
- "guidance_scale": scale,
117
- },
118
- )
119
 
120
  return images, gr.update(visible=True)
121
 
 
5
  import base64
6
  import re
7
  import os
8
+ import random
9
  import requests
10
  import time
11
  from PIL import Image
 
89
  if re.search(rf"\b{filter}\b", prompt):
90
  raise gr.Error("Please try again with a different prompt")
91
 
92
+ seed = random.randint(0,4294967295)
93
  prompt, negative = apply_style(style_name, prompt, negative)
94
  images = []
95
  url = os.getenv('JAX_BACKEND_URL')
96
+ payload = {'instances': [{ 'prompt': prompt, 'negative_prompt': negative, 'parameters':{ 'guidance_scale': scale, 'seed': seed } }] }
97
  start_time = time.time()
98
  images_request = requests.post(url, json = payload)
99
  print(time.time() - start_time)
 
102
  except requests.exceptions.JSONDecodeError:
103
  raise gr.Error("SDXL did not return a valid result, try again")
104
 
105
+ for prediction in json_data["predictions"]:
106
+ for image in prediction["images"]:
107
+ image_b64 = (f"data:image/jpeg;base64,{image}")
108
+ images.append(image_b64)
109
 
110
+ if profile is not None: # avoid conversion on non-logged-in users
111
+ pil_image = Image.open(BytesIO(base64.b64decode(image)))
112
+ user_history.save_image( # save images + metadata to user history
113
+ label=prompt,
114
+ image=pil_image,
115
+ profile=profile,
116
+ metadata={
117
+ "prompt": prompt,
118
+ "negative_prompt": negative,
119
+ "guidance_scale": scale,
120
+ },
121
+ )
122
 
123
  return images, gr.update(visible=True)
124