adamelliotfields
commited on
Add Hugging Face txt2img client
Browse files- lib/__init__.py +2 -1
- lib/api.py +39 -0
- lib/config.py +3 -3
- pages/2_🎨_Text_to_Image.py +26 -25
lib/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
|
|
1 |
from .config import Config
|
2 |
from .presets import Presets
|
3 |
|
4 |
-
__all__ = ["Config", "Presets"]
|
|
|
1 |
+
from .api import HuggingFaceTxt2ImgAPI
|
2 |
from .config import Config
|
3 |
from .presets import Presets
|
4 |
|
5 |
+
__all__ = ["Config", "HuggingFaceTxt2ImgAPI", "Presets"]
|
lib/api.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
|
4 |
+
import requests
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
class Txt2ImgAPI(ABC):
|
9 |
+
@abstractmethod
|
10 |
+
def generate_image(self, model, prompt, parameters, **kwargs):
|
11 |
+
pass
|
12 |
+
|
13 |
+
|
14 |
+
# essentially the same as huggingface_hub's inference client
|
15 |
+
class HuggingFaceTxt2ImgAPI(Txt2ImgAPI):
|
16 |
+
def __init__(self, token):
|
17 |
+
self.api_url = "https://api-inference.huggingface.co/models"
|
18 |
+
self.headers = {
|
19 |
+
"Authorization": f"Bearer {token}",
|
20 |
+
"X-Wait-For-Model": "true",
|
21 |
+
"X-Use-Cache": "false",
|
22 |
+
}
|
23 |
+
|
24 |
+
def generate_image(self, model, prompt, parameters, **kwargs):
|
25 |
+
try:
|
26 |
+
response = requests.post(
|
27 |
+
f"{self.api_url}/{model}",
|
28 |
+
headers=self.headers,
|
29 |
+
json={
|
30 |
+
"inputs": prompt,
|
31 |
+
"parameters": {**parameters, **kwargs},
|
32 |
+
},
|
33 |
+
)
|
34 |
+
if response.status_code == 200:
|
35 |
+
return Image.open(io.BytesIO(response.content))
|
36 |
+
else:
|
37 |
+
raise Exception(f"Error: {response.status_code} - {response.text}")
|
38 |
+
except Exception as e:
|
39 |
+
return str(e)
|
lib/config.py
CHANGED
@@ -4,7 +4,7 @@ Config = SimpleNamespace(
|
|
4 |
TITLE="API Inference",
|
5 |
ICON="⚡",
|
6 |
LAYOUT="wide",
|
7 |
-
TXT2IMG_NEGATIVE_PROMPT="ugly,
|
8 |
TXT2IMG_DEFAULT_MODEL=2,
|
9 |
TXT2IMG_MODELS=[
|
10 |
"black-forest-labs/flux.1-dev",
|
@@ -13,11 +13,11 @@ Config = SimpleNamespace(
|
|
13 |
],
|
14 |
TXT2IMG_DEFAULT_AR="1:1",
|
15 |
TXT2IMG_AR={
|
16 |
-
"9:7": (1152, 896),
|
17 |
"7:4": (1344, 768),
|
|
|
18 |
"1:1": (1024, 1024),
|
19 |
-
"4:7": (768, 1344),
|
20 |
"7:9": (896, 1152),
|
|
|
21 |
},
|
22 |
TXT2TXT_DEFAULT_MODEL=4,
|
23 |
TXT2TXT_MODELS=[
|
|
|
4 |
TITLE="API Inference",
|
5 |
ICON="⚡",
|
6 |
LAYOUT="wide",
|
7 |
+
TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
|
8 |
TXT2IMG_DEFAULT_MODEL=2,
|
9 |
TXT2IMG_MODELS=[
|
10 |
"black-forest-labs/flux.1-dev",
|
|
|
13 |
],
|
14 |
TXT2IMG_DEFAULT_AR="1:1",
|
15 |
TXT2IMG_AR={
|
|
|
16 |
"7:4": (1344, 768),
|
17 |
+
"9:7": (1152, 896),
|
18 |
"1:1": (1024, 1024),
|
|
|
19 |
"7:9": (896, 1152),
|
20 |
+
"4:7": (768, 1344),
|
21 |
},
|
22 |
TXT2TXT_DEFAULT_MODEL=4,
|
23 |
TXT2TXT_MODELS=[
|
pages/2_🎨_Text_to_Image.py
CHANGED
@@ -1,12 +1,9 @@
|
|
1 |
-
import io
|
2 |
import os
|
3 |
from datetime import datetime
|
4 |
|
5 |
-
import requests
|
6 |
import streamlit as st
|
7 |
-
from PIL import Image
|
8 |
|
9 |
-
from lib import Config, Presets
|
10 |
|
11 |
# TODO: key input and store in cache_data
|
12 |
# TODO: API dropdown; changes available models
|
@@ -20,22 +17,9 @@ PRESET_MODEL = {
|
|
20 |
}
|
21 |
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
headers=HEADERS,
|
27 |
-
json={
|
28 |
-
"inputs": prompt,
|
29 |
-
"parameters": {**parameters, **kwargs},
|
30 |
-
},
|
31 |
-
)
|
32 |
-
|
33 |
-
if response.status_code == 200:
|
34 |
-
image = Image.open(io.BytesIO(response.content))
|
35 |
-
return image
|
36 |
-
else:
|
37 |
-
st.error(f"Error: {response.status_code} - {response.text}")
|
38 |
-
return None
|
39 |
|
40 |
|
41 |
# config
|
@@ -46,6 +30,9 @@ st.set_page_config(
|
|
46 |
)
|
47 |
|
48 |
# initialize state
|
|
|
|
|
|
|
49 |
if "txt2img_messages" not in st.session_state:
|
50 |
st.session_state.txt2img_messages = []
|
51 |
|
@@ -60,11 +47,13 @@ model = st.sidebar.selectbox(
|
|
60 |
format_func=lambda x: x.split("/")[1],
|
61 |
options=Config.TXT2IMG_MODELS,
|
62 |
index=Config.TXT2IMG_DEFAULT_MODEL,
|
|
|
63 |
)
|
64 |
aspect_ratio = st.sidebar.select_slider(
|
65 |
"Aspect Ratio",
|
66 |
options=list(Config.TXT2IMG_AR.keys()),
|
67 |
value=Config.TXT2IMG_DEFAULT_AR,
|
|
|
68 |
)
|
69 |
|
70 |
# heading
|
@@ -88,6 +77,7 @@ for param in preset["parameters"]:
|
|
88 |
preset["guidance_scale_max"],
|
89 |
preset["guidance_scale"],
|
90 |
0.1,
|
|
|
91 |
)
|
92 |
if param == "num_inference_steps":
|
93 |
parameters[param] = st.sidebar.slider(
|
@@ -96,6 +86,7 @@ for param in preset["parameters"]:
|
|
96 |
preset["num_inference_steps_max"],
|
97 |
preset["num_inference_steps"],
|
98 |
1,
|
|
|
99 |
)
|
100 |
if param == "seed":
|
101 |
parameters[param] = st.sidebar.number_input(
|
@@ -103,11 +94,13 @@ for param in preset["parameters"]:
|
|
103 |
min_value=-1,
|
104 |
max_value=(1 << 53) - 1,
|
105 |
value=-1,
|
|
|
106 |
)
|
107 |
if param == "negative_prompt":
|
108 |
parameters[param] = st.sidebar.text_area(
|
109 |
label="Negative Prompt",
|
110 |
value=Config.TXT2IMG_NEGATIVE_PROMPT,
|
|
|
111 |
)
|
112 |
|
113 |
# wrap the prompt in an expander to display additional parameters
|
@@ -142,7 +135,7 @@ for message in st.session_state.txt2img_messages:
|
|
142 |
}
|
143 |
</style>
|
144 |
""")
|
145 |
-
st.
|
146 |
|
147 |
# button row
|
148 |
if st.session_state.txt2img_messages:
|
@@ -162,13 +155,16 @@ if st.session_state.txt2img_messages:
|
|
162 |
# retry
|
163 |
col1, col2 = st.columns(2)
|
164 |
with col1:
|
165 |
-
if
|
|
|
|
|
|
|
166 |
st.session_state.txt2img_messages.pop()
|
167 |
st.session_state.txt2img_messages.pop()
|
168 |
st.rerun()
|
169 |
|
170 |
with col2:
|
171 |
-
if st.button("🗑️", help="Clear all generations"):
|
172 |
st.session_state.txt2img_messages = []
|
173 |
st.session_state.txt2img_seed = 0
|
174 |
st.rerun()
|
@@ -176,7 +172,10 @@ else:
|
|
176 |
button_container = None
|
177 |
|
178 |
# show the prompt and spinner while loading then update state and re-render
|
179 |
-
if prompt := st.chat_input(
|
|
|
|
|
|
|
180 |
if "seed" in parameters and parameters["seed"] >= 0:
|
181 |
st.session_state.txt2img_seed = parameters["seed"]
|
182 |
else:
|
@@ -195,7 +194,9 @@ if prompt := st.chat_input("What do you want to see?"):
|
|
195 |
generate_kwargs = {"model": model, "prompt": prompt, "parameters": parameters}
|
196 |
if preset.get("kwargs") is not None:
|
197 |
generate_kwargs.update(preset["kwargs"])
|
198 |
-
|
|
|
|
|
199 |
|
200 |
model_name = PRESET_MODEL[model]["name"]
|
201 |
st.session_state.txt2img_messages.append(
|
|
|
|
|
1 |
import os
|
2 |
from datetime import datetime
|
3 |
|
|
|
4 |
import streamlit as st
|
|
|
5 |
|
6 |
+
from lib import Config, HuggingFaceTxt2ImgAPI, Presets
|
7 |
|
8 |
# TODO: key input and store in cache_data
|
9 |
# TODO: API dropdown; changes available models
|
|
|
17 |
}
|
18 |
|
19 |
|
20 |
+
@st.cache_resource
|
21 |
+
def get_txt2img_api():
|
22 |
+
return HuggingFaceTxt2ImgAPI(HF_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
# config
|
|
|
30 |
)
|
31 |
|
32 |
# initialize state
|
33 |
+
if "txt2img_running" not in st.session_state:
|
34 |
+
st.session_state.txt2img_running = False
|
35 |
+
|
36 |
if "txt2img_messages" not in st.session_state:
|
37 |
st.session_state.txt2img_messages = []
|
38 |
|
|
|
47 |
format_func=lambda x: x.split("/")[1],
|
48 |
options=Config.TXT2IMG_MODELS,
|
49 |
index=Config.TXT2IMG_DEFAULT_MODEL,
|
50 |
+
disabled=st.session_state.txt2img_running,
|
51 |
)
|
52 |
aspect_ratio = st.sidebar.select_slider(
|
53 |
"Aspect Ratio",
|
54 |
options=list(Config.TXT2IMG_AR.keys()),
|
55 |
value=Config.TXT2IMG_DEFAULT_AR,
|
56 |
+
disabled=st.session_state.txt2img_running,
|
57 |
)
|
58 |
|
59 |
# heading
|
|
|
77 |
preset["guidance_scale_max"],
|
78 |
preset["guidance_scale"],
|
79 |
0.1,
|
80 |
+
disabled=st.session_state.txt2img_running,
|
81 |
)
|
82 |
if param == "num_inference_steps":
|
83 |
parameters[param] = st.sidebar.slider(
|
|
|
86 |
preset["num_inference_steps_max"],
|
87 |
preset["num_inference_steps"],
|
88 |
1,
|
89 |
+
disabled=st.session_state.txt2img_running,
|
90 |
)
|
91 |
if param == "seed":
|
92 |
parameters[param] = st.sidebar.number_input(
|
|
|
94 |
min_value=-1,
|
95 |
max_value=(1 << 53) - 1,
|
96 |
value=-1,
|
97 |
+
disabled=st.session_state.txt2img_running,
|
98 |
)
|
99 |
if param == "negative_prompt":
|
100 |
parameters[param] = st.sidebar.text_area(
|
101 |
label="Negative Prompt",
|
102 |
value=Config.TXT2IMG_NEGATIVE_PROMPT,
|
103 |
+
disabled=st.session_state.txt2img_running,
|
104 |
)
|
105 |
|
106 |
# wrap the prompt in an expander to display additional parameters
|
|
|
135 |
}
|
136 |
</style>
|
137 |
""")
|
138 |
+
st.write(message["content"]) # success will be image, error will be text
|
139 |
|
140 |
# button row
|
141 |
if st.session_state.txt2img_messages:
|
|
|
155 |
# retry
|
156 |
col1, col2 = st.columns(2)
|
157 |
with col1:
|
158 |
+
if (
|
159 |
+
st.button("❌", help="Delete last generation", disabled=st.session_state.txt2img_running)
|
160 |
+
and len(st.session_state.txt2img_messages) >= 2
|
161 |
+
):
|
162 |
st.session_state.txt2img_messages.pop()
|
163 |
st.session_state.txt2img_messages.pop()
|
164 |
st.rerun()
|
165 |
|
166 |
with col2:
|
167 |
+
if st.button("🗑️", help="Clear all generations", disabled=st.session_state.txt2img_running):
|
168 |
st.session_state.txt2img_messages = []
|
169 |
st.session_state.txt2img_seed = 0
|
170 |
st.rerun()
|
|
|
172 |
button_container = None
|
173 |
|
174 |
# show the prompt and spinner while loading then update state and re-render
|
175 |
+
if prompt := st.chat_input(
|
176 |
+
"What do you want to see?",
|
177 |
+
on_submit=lambda: setattr(st.session_state, "txt2img_running", True),
|
178 |
+
):
|
179 |
if "seed" in parameters and parameters["seed"] >= 0:
|
180 |
st.session_state.txt2img_seed = parameters["seed"]
|
181 |
else:
|
|
|
194 |
generate_kwargs = {"model": model, "prompt": prompt, "parameters": parameters}
|
195 |
if preset.get("kwargs") is not None:
|
196 |
generate_kwargs.update(preset["kwargs"])
|
197 |
+
api = get_txt2img_api()
|
198 |
+
image = api.generate_image(**generate_kwargs)
|
199 |
+
st.session_state.txt2img_running = False
|
200 |
|
201 |
model_name = PRESET_MODEL[model]["name"]
|
202 |
st.session_state.txt2img_messages.append(
|