adamelliotfields committed
Commit 39d4ddb · verified · 1 Parent(s): 11aa54f

Add Hugging Face txt2img client

Files changed (4):
  1. lib/__init__.py +2 -1
  2. lib/api.py +39 -0
  3. lib/config.py +3 -3
  4. pages/2_🎨_Text_to_Image.py +26 -25
lib/__init__.py CHANGED
@@ -1,4 +1,5 @@
+from .api import HuggingFaceTxt2ImgAPI
 from .config import Config
 from .presets import Presets
 
-__all__ = ["Config", "Presets"]
+__all__ = ["Config", "HuggingFaceTxt2ImgAPI", "Presets"]
lib/api.py ADDED
@@ -0,0 +1,39 @@
+import io
+from abc import ABC, abstractmethod
+
+import requests
+from PIL import Image
+
+
+class Txt2ImgAPI(ABC):
+    @abstractmethod
+    def generate_image(self, model, prompt, parameters, **kwargs):
+        pass
+
+
+# essentially the same as huggingface_hub's inference client
+class HuggingFaceTxt2ImgAPI(Txt2ImgAPI):
+    def __init__(self, token):
+        self.api_url = "https://api-inference.huggingface.co/models"
+        self.headers = {
+            "Authorization": f"Bearer {token}",
+            "X-Wait-For-Model": "true",
+            "X-Use-Cache": "false",
+        }
+
+    def generate_image(self, model, prompt, parameters, **kwargs):
+        try:
+            response = requests.post(
+                f"{self.api_url}/{model}",
+                headers=self.headers,
+                json={
+                    "inputs": prompt,
+                    "parameters": {**parameters, **kwargs},
+                },
+            )
+            if response.status_code == 200:
+                return Image.open(io.BytesIO(response.content))
+            else:
+                raise Exception(f"Error: {response.status_code} - {response.text}")
+        except Exception as e:
+            return str(e)
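
For reference, a minimal standalone sketch of how the new client might be exercised outside Streamlit. The `os.environ["HF_TOKEN"]` lookup, the prompt, and the output path are assumptions for illustration; the model id is the first entry in `Config.TXT2IMG_MODELS`.

```python
import os

from lib import HuggingFaceTxt2ImgAPI

# Hypothetical token source; the page module defines its own HF_TOKEN elsewhere.
api = HuggingFaceTxt2ImgAPI(os.environ["HF_TOKEN"])

result = api.generate_image(
    model="black-forest-labs/flux.1-dev",
    prompt="a watercolor fox in a snowy forest",  # example prompt
    parameters={"width": 1024, "height": 1024},
)

# generate_image returns a PIL.Image on success and the error string on failure.
if isinstance(result, str):
    print(result)
else:
    result.save("out.png")
```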
lib/config.py CHANGED
@@ -4,7 +4,7 @@ Config = SimpleNamespace(
     TITLE="API Inference",
     ICON="⚡",
     LAYOUT="wide",
-    TXT2IMG_NEGATIVE_PROMPT="ugly, bad, asymmetrical, malformed, mutated, disgusting, blurry, grainy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark",
+    TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
     TXT2IMG_DEFAULT_MODEL=2,
     TXT2IMG_MODELS=[
         "black-forest-labs/flux.1-dev",
@@ -13,11 +13,11 @@ Config = SimpleNamespace(
     ],
     TXT2IMG_DEFAULT_AR="1:1",
     TXT2IMG_AR={
-        "9:7": (1152, 896),
         "7:4": (1344, 768),
+        "9:7": (1152, 896),
         "1:1": (1024, 1024),
-        "4:7": (768, 1344),
         "7:9": (896, 1152),
+        "4:7": (768, 1344),
     },
     TXT2TXT_DEFAULT_MODEL=4,
     TXT2TXT_MODELS=[
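
The aspect-ratio table maps a label to a `(width, height)` tuple, and the reordering above simply lists entries from widest to tallest so the slider reads left to right. A minimal sketch of how a selected ratio would resolve into generation parameters (this lookup itself is not part of the diff):

```python
from lib import Config

# Example: resolve the default ratio ("1:1") to explicit dimensions.
aspect_ratio = Config.TXT2IMG_DEFAULT_AR
width, height = Config.TXT2IMG_AR[aspect_ratio]  # (1024, 1024)

parameters = {"width": width, "height": height}
```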
pages/2_🎨_Text_to_Image.py CHANGED
@@ -1,12 +1,9 @@
-import io
 import os
 from datetime import datetime
 
-import requests
 import streamlit as st
-from PIL import Image
 
-from lib import Config, Presets
+from lib import Config, HuggingFaceTxt2ImgAPI, Presets
 
 # TODO: key input and store in cache_data
 # TODO: API dropdown; changes available models
@@ -20,22 +17,9 @@ PRESET_MODEL = {
 }
 
 
-def generate_image(model, prompt, parameters, **kwargs):
-    response = requests.post(
-        f"{API_URL}/{model}",
-        headers=HEADERS,
-        json={
-            "inputs": prompt,
-            "parameters": {**parameters, **kwargs},
-        },
-    )
-
-    if response.status_code == 200:
-        image = Image.open(io.BytesIO(response.content))
-        return image
-    else:
-        st.error(f"Error: {response.status_code} - {response.text}")
-        return None
+@st.cache_resource
+def get_txt2img_api():
+    return HuggingFaceTxt2ImgAPI(HF_TOKEN)
 
 
 # config
@@ -46,6 +30,9 @@ st.set_page_config(
 )
 
 # initialize state
+if "txt2img_running" not in st.session_state:
+    st.session_state.txt2img_running = False
+
 if "txt2img_messages" not in st.session_state:
     st.session_state.txt2img_messages = []
 
@@ -60,11 +47,13 @@ model = st.sidebar.selectbox(
     format_func=lambda x: x.split("/")[1],
     options=Config.TXT2IMG_MODELS,
     index=Config.TXT2IMG_DEFAULT_MODEL,
+    disabled=st.session_state.txt2img_running,
 )
 aspect_ratio = st.sidebar.select_slider(
     "Aspect Ratio",
     options=list(Config.TXT2IMG_AR.keys()),
     value=Config.TXT2IMG_DEFAULT_AR,
+    disabled=st.session_state.txt2img_running,
 )
 
 # heading
@@ -88,6 +77,7 @@ for param in preset["parameters"]:
             preset["guidance_scale_max"],
             preset["guidance_scale"],
             0.1,
+            disabled=st.session_state.txt2img_running,
         )
     if param == "num_inference_steps":
         parameters[param] = st.sidebar.slider(
@@ -96,6 +86,7 @@ for param in preset["parameters"]:
             preset["num_inference_steps_max"],
             preset["num_inference_steps"],
             1,
+            disabled=st.session_state.txt2img_running,
         )
     if param == "seed":
         parameters[param] = st.sidebar.number_input(
@@ -103,11 +94,13 @@ for param in preset["parameters"]:
             min_value=-1,
             max_value=(1 << 53) - 1,
             value=-1,
+            disabled=st.session_state.txt2img_running,
         )
     if param == "negative_prompt":
         parameters[param] = st.sidebar.text_area(
             label="Negative Prompt",
             value=Config.TXT2IMG_NEGATIVE_PROMPT,
+            disabled=st.session_state.txt2img_running,
         )
 
 # wrap the prompt in an expander to display additional parameters
@@ -142,7 +135,7 @@ for message in st.session_state.txt2img_messages:
         }
         </style>
        """)
-        st.image(message["content"])
+        st.write(message["content"])  # success will be image, error will be text
 
 # button row
 if st.session_state.txt2img_messages:
@@ -162,13 +155,16 @@ if st.session_state.txt2img_messages:
     # retry
     col1, col2 = st.columns(2)
     with col1:
-        if st.button("❌", help="Delete last generation") and len(st.session_state.txt2img_messages) >= 2:
+        if (
+            st.button("❌", help="Delete last generation", disabled=st.session_state.txt2img_running)
+            and len(st.session_state.txt2img_messages) >= 2
+        ):
            st.session_state.txt2img_messages.pop()
            st.session_state.txt2img_messages.pop()
            st.rerun()
 
     with col2:
-        if st.button("🗑️", help="Clear all generations"):
+        if st.button("🗑️", help="Clear all generations", disabled=st.session_state.txt2img_running):
            st.session_state.txt2img_messages = []
            st.session_state.txt2img_seed = 0
            st.rerun()
@@ -176,7 +172,10 @@ else:
     button_container = None
 
 # show the prompt and spinner while loading then update state and re-render
-if prompt := st.chat_input("What do you want to see?"):
+if prompt := st.chat_input(
+    "What do you want to see?",
+    on_submit=lambda: setattr(st.session_state, "txt2img_running", True),
+):
     if "seed" in parameters and parameters["seed"] >= 0:
         st.session_state.txt2img_seed = parameters["seed"]
     else:
@@ -195,7 +194,9 @@ if prompt := st.chat_input("What do you want to see?"):
         generate_kwargs = {"model": model, "prompt": prompt, "parameters": parameters}
         if preset.get("kwargs") is not None:
             generate_kwargs.update(preset["kwargs"])
-        image = generate_image(**generate_kwargs)
+        api = get_txt2img_api()
+        image = api.generate_image(**generate_kwargs)
+        st.session_state.txt2img_running = False
 
         model_name = PRESET_MODEL[model]["name"]
         st.session_state.txt2img_messages.append(
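
Putting the page changes together, the pattern is: a `@st.cache_resource` factory so one client instance is reused across reruns, plus a `txt2img_running` flag set by `st.chat_input`'s `on_submit` callback so the sidebar widgets render disabled while a request is in flight. A condensed sketch of that pattern follows; the names and the `st.secrets` token source are simplified assumptions, not the full page.

```python
import streamlit as st

from lib import HuggingFaceTxt2ImgAPI


@st.cache_resource
def get_api():
    # Token source is an assumption; the page uses an HF_TOKEN defined elsewhere.
    return HuggingFaceTxt2ImgAPI(st.secrets["hf_token"])


if "running" not in st.session_state:
    st.session_state.running = False

# on_submit fires before the rerun, so the widgets below render as disabled
# for the run that performs the generation.
prompt = st.chat_input(
    "What do you want to see?",
    on_submit=lambda: setattr(st.session_state, "running", True),
)
model = st.sidebar.selectbox(
    "Model", ["black-forest-labs/flux.1-dev"], disabled=st.session_state.running
)

if prompt:
    result = get_api().generate_image(model, prompt, {})
    st.session_state.running = False  # widgets re-enable on the next rerun
    st.write(result)  # PIL.Image on success, error string on failure
```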