adamelliotfields
commited on
Client improvements
Browse files- lib/__init__.py +3 -4
- lib/api.py +40 -79
- lib/config.py +4 -0
- lib/presets.py +2 -0
- pages/1_💬_Text_Generation.py +52 -46
- pages/2_🎨_Text_to_Image.py +49 -16
lib/__init__.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
-
from .api import
|
2 |
from .config import Config
|
3 |
from .presets import ModelPresets, ServicePresets
|
4 |
|
5 |
__all__ = [
|
6 |
"Config",
|
7 |
-
"HuggingFaceTxt2ImgAPI",
|
8 |
-
"HuggingFaceTxt2TxtAPI",
|
9 |
"ModelPresets",
|
10 |
-
"PerplexityTxt2TxtAPI",
|
11 |
"ServicePresets",
|
|
|
|
|
12 |
]
|
|
|
1 |
+
from .api import txt2img_generate, txt2txt_generate
|
2 |
from .config import Config
|
3 |
from .presets import ModelPresets, ServicePresets
|
4 |
|
5 |
__all__ = [
|
6 |
"Config",
|
|
|
|
|
7 |
"ModelPresets",
|
|
|
8 |
"ServicePresets",
|
9 |
+
"txt2img_generate",
|
10 |
+
"txt2txt_generate",
|
11 |
]
|
lib/api.py
CHANGED
@@ -1,87 +1,48 @@
|
|
1 |
import io
|
2 |
-
from abc import ABC, abstractmethod
|
3 |
|
4 |
import requests
|
5 |
import streamlit as st
|
6 |
from openai import APIError, OpenAI
|
7 |
from PIL import Image
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
)
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
return str(e)
|
40 |
-
|
41 |
-
|
42 |
-
class PerplexityTxt2TxtAPI(Txt2TxtAPI):
|
43 |
-
def __init__(self, api_key):
|
44 |
-
self.api_key = api_key
|
45 |
-
|
46 |
-
def generate_text(self, model, parameters, **kwargs):
|
47 |
-
if not self.api_key:
|
48 |
-
return "API Key is required."
|
49 |
-
client = OpenAI(
|
50 |
-
api_key=self.api_key,
|
51 |
-
base_url="https://api.perplexity.ai",
|
52 |
-
)
|
53 |
-
try:
|
54 |
-
stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
|
55 |
-
return st.write_stream(stream)
|
56 |
-
except APIError as e:
|
57 |
-
return e.message
|
58 |
-
except Exception as e:
|
59 |
-
return str(e)
|
60 |
-
|
61 |
-
|
62 |
-
# essentially the same as huggingface_hub's inference client
|
63 |
-
class HuggingFaceTxt2ImgAPI(Txt2ImgAPI):
|
64 |
-
def __init__(self, token):
|
65 |
-
self.api_url = "https://api-inference.huggingface.co/models"
|
66 |
-
self.headers = {
|
67 |
-
"Authorization": f"Bearer {token}",
|
68 |
-
"X-Wait-For-Model": "true",
|
69 |
-
"X-Use-Cache": "false",
|
70 |
-
}
|
71 |
-
|
72 |
-
def generate_image(self, model, prompt, parameters, **kwargs):
|
73 |
-
try:
|
74 |
-
response = requests.post(
|
75 |
-
f"{self.api_url}/{model}",
|
76 |
-
headers=self.headers,
|
77 |
-
json={
|
78 |
-
"inputs": prompt,
|
79 |
-
"parameters": {**parameters, **kwargs},
|
80 |
-
},
|
81 |
-
)
|
82 |
-
if response.status_code == 200:
|
83 |
-
return Image.open(io.BytesIO(response.content))
|
84 |
-
else:
|
85 |
-
raise Exception(f"Error: {response.status_code} - {response.text}")
|
86 |
-
except Exception as e:
|
87 |
-
return str(e)
|
|
|
1 |
import io
|
|
|
2 |
|
3 |
import requests
|
4 |
import streamlit as st
|
5 |
from openai import APIError, OpenAI
|
6 |
from PIL import Image
|
7 |
|
8 |
+
from .config import Config
|
9 |
+
|
10 |
+
|
11 |
+
def txt2txt_generate(api_key, service, model, parameters, **kwargs):
|
12 |
+
base_url = Config.SERVICES[service]
|
13 |
+
if service == "Huggingface":
|
14 |
+
base_url = f"{base_url}/{model}/v1"
|
15 |
+
client = OpenAI(api_key=api_key, base_url=base_url)
|
16 |
+
|
17 |
+
try:
|
18 |
+
stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
|
19 |
+
return st.write_stream(stream)
|
20 |
+
except APIError as e:
|
21 |
+
return e.message
|
22 |
+
except Exception as e:
|
23 |
+
return str(e)
|
24 |
+
|
25 |
+
|
26 |
+
def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
|
27 |
+
headers = {
|
28 |
+
"Authorization": f"Bearer {api_key}",
|
29 |
+
"X-Wait-For-Model": "true",
|
30 |
+
"X-Use-Cache": "false",
|
31 |
+
}
|
32 |
+
base_url = f"{Config.SERVICES[service]}/{model}"
|
33 |
+
|
34 |
+
try:
|
35 |
+
response = requests.post(
|
36 |
+
base_url,
|
37 |
+
headers=headers,
|
38 |
+
json={
|
39 |
+
"inputs": inputs,
|
40 |
+
"parameters": {**parameters, **kwargs},
|
41 |
+
},
|
42 |
)
|
43 |
+
if response.status_code == 200:
|
44 |
+
return Image.open(io.BytesIO(response.content))
|
45 |
+
else:
|
46 |
+
return f"Error: {response.status_code} - {response.text}"
|
47 |
+
except Exception as e:
|
48 |
+
return str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lib/config.py
CHANGED
@@ -4,6 +4,10 @@ Config = SimpleNamespace(
|
|
4 |
TITLE="API Inference",
|
5 |
ICON="⚡",
|
6 |
LAYOUT="wide",
|
|
|
|
|
|
|
|
|
7 |
TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
|
8 |
TXT2IMG_DEFAULT_MODEL=2,
|
9 |
TXT2IMG_MODELS=[
|
|
|
4 |
TITLE="API Inference",
|
5 |
ICON="⚡",
|
6 |
LAYOUT="wide",
|
7 |
+
SERVICES={
|
8 |
+
"Huggingface": "https://api-inference.huggingface.co/models",
|
9 |
+
"Perplexity": "https://api.perplexity.ai",
|
10 |
+
},
|
11 |
TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
|
12 |
TXT2IMG_DEFAULT_MODEL=2,
|
13 |
TXT2IMG_MODELS=[
|
lib/presets.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from types import SimpleNamespace
|
2 |
|
|
|
3 |
ServicePresets = SimpleNamespace(
|
4 |
Huggingface={
|
5 |
# every service has model and system messages
|
@@ -16,6 +17,7 @@ ServicePresets = SimpleNamespace(
|
|
16 |
},
|
17 |
)
|
18 |
|
|
|
19 |
ModelPresets = SimpleNamespace(
|
20 |
FLUX_1_DEV={
|
21 |
"name": "FLUX.1 Dev",
|
|
|
1 |
from types import SimpleNamespace
|
2 |
|
3 |
+
# txt2txt services
|
4 |
ServicePresets = SimpleNamespace(
|
5 |
Huggingface={
|
6 |
# every service has model and system messages
|
|
|
17 |
},
|
18 |
)
|
19 |
|
20 |
+
# txt2img models
|
21 |
ModelPresets = SimpleNamespace(
|
22 |
FLUX_1_DEV={
|
23 |
"name": "FLUX.1 Dev",
|
pages/1_💬_Text_Generation.py
CHANGED
@@ -3,21 +3,12 @@ from datetime import datetime
|
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
-
from lib import Config,
|
7 |
|
8 |
HF_TOKEN = os.environ.get("HF_TOKEN") or None
|
9 |
PPLX_API_KEY = os.environ.get("PPLX_API_KEY") or None
|
10 |
|
11 |
|
12 |
-
@st.cache_resource
|
13 |
-
def get_txt2txt_api(service="Huggingface", api_key=None):
|
14 |
-
if service == "Huggingface":
|
15 |
-
return HuggingFaceTxt2TxtAPI(api_key)
|
16 |
-
if service == "Perplexity":
|
17 |
-
return PerplexityTxt2TxtAPI(api_key)
|
18 |
-
return None
|
19 |
-
|
20 |
-
|
21 |
# config
|
22 |
st.set_page_config(
|
23 |
page_title=f"{Config.TITLE} | Text Generation",
|
@@ -26,8 +17,11 @@ st.set_page_config(
|
|
26 |
)
|
27 |
|
28 |
# initialize state
|
29 |
-
if "
|
30 |
-
st.session_state.
|
|
|
|
|
|
|
31 |
|
32 |
if "txt2txt_messages" not in st.session_state:
|
33 |
st.session_state.txt2txt_messages = []
|
@@ -35,6 +29,9 @@ if "txt2txt_messages" not in st.session_state:
|
|
35 |
if "txt2txt_prompt" not in st.session_state:
|
36 |
st.session_state.txt2txt_prompt = ""
|
37 |
|
|
|
|
|
|
|
38 |
# sidebar
|
39 |
st.logo("logo.svg")
|
40 |
st.sidebar.header("Settings")
|
@@ -45,22 +42,40 @@ service = st.sidebar.selectbox(
|
|
45 |
disabled=st.session_state.txt2txt_running,
|
46 |
)
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
api_key = st.sidebar.text_input(
|
51 |
"API Key",
|
52 |
-
value="",
|
53 |
type="password",
|
54 |
help="Cleared on page refresh",
|
55 |
disabled=st.session_state.txt2txt_running,
|
|
|
56 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
model = st.sidebar.selectbox(
|
59 |
"Model",
|
60 |
-
format_func=lambda x: x.split("/")[1] if service == "Huggingface" else x,
|
61 |
-
index=Config.TXT2TXT_DEFAULT_MODEL[service],
|
62 |
options=Config.TXT2TXT_MODELS[service],
|
|
|
63 |
disabled=st.session_state.txt2txt_running,
|
|
|
64 |
)
|
65 |
system = st.sidebar.text_area(
|
66 |
"System Message",
|
@@ -70,52 +85,48 @@ system = st.sidebar.text_area(
|
|
70 |
|
71 |
# build parameters from preset
|
72 |
parameters = {}
|
73 |
-
preset = getattr(ServicePresets, service)
|
74 |
for param in preset["parameters"]:
|
75 |
if param == "max_tokens":
|
76 |
parameters[param] = st.sidebar.slider(
|
77 |
"Max Tokens",
|
|
|
|
|
78 |
min_value=512,
|
79 |
max_value=4096,
|
80 |
-
value=512,
|
81 |
-
step=128,
|
82 |
-
help="Maximum number of tokens to generate (default: 512)",
|
83 |
disabled=st.session_state.txt2txt_running,
|
|
|
84 |
)
|
85 |
if param == "temperature":
|
86 |
parameters[param] = st.sidebar.slider(
|
87 |
"Temperature",
|
|
|
|
|
88 |
min_value=0.0,
|
89 |
max_value=2.0,
|
90 |
-
value=1.0,
|
91 |
-
step=0.1,
|
92 |
-
help="Used to modulate the next token probabilities (default: 1.0)",
|
93 |
disabled=st.session_state.txt2txt_running,
|
|
|
94 |
)
|
95 |
if param == "frequency_penalty":
|
96 |
parameters[param] = st.sidebar.slider(
|
97 |
"Frequency Penalty",
|
|
|
|
|
98 |
min_value=preset["frequency_penalty_min"],
|
99 |
max_value=preset["frequency_penalty_max"],
|
100 |
-
value=preset["frequency_penalty"],
|
101 |
-
step=0.1,
|
102 |
-
help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
|
103 |
disabled=st.session_state.txt2txt_running,
|
|
|
104 |
)
|
105 |
if param == "seed":
|
106 |
parameters[param] = st.sidebar.number_input(
|
107 |
"Seed",
|
|
|
108 |
min_value=-1,
|
109 |
max_value=(1 << 53) - 1,
|
110 |
-
value=-1,
|
111 |
-
help="Make a best effort to sample deterministically (default: -1)",
|
112 |
disabled=st.session_state.txt2txt_running,
|
|
|
113 |
)
|
114 |
|
115 |
-
# random seed
|
116 |
-
if parameters.get("seed", 0) < 0:
|
117 |
-
parameters["seed"] = int(datetime.now().timestamp() * 1e6) % (1 << 53)
|
118 |
-
|
119 |
# heading
|
120 |
st.html("""
|
121 |
<h1 style="padding: 0; margin-bottom: 0.5rem">Text Generation</h1>
|
@@ -173,9 +184,8 @@ if prompt := st.chat_input(
|
|
173 |
):
|
174 |
st.session_state.txt2txt_prompt = prompt
|
175 |
|
176 |
-
if
|
177 |
-
|
178 |
-
st.markdown(st.session_state.txt2txt_prompt)
|
179 |
|
180 |
if button_container:
|
181 |
button_container.empty()
|
@@ -185,16 +195,12 @@ if st.session_state.txt2txt_prompt:
|
|
185 |
messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
|
186 |
parameters["messages"] = messages
|
187 |
|
|
|
|
|
|
|
188 |
with st.chat_message("assistant"):
|
189 |
-
|
190 |
-
|
191 |
-
key = HF_TOKEN
|
192 |
-
elif service == "Perplexity" and PPLX_API_KEY is not None:
|
193 |
-
key = PPLX_API_KEY
|
194 |
-
else:
|
195 |
-
key = api_key
|
196 |
-
api = get_txt2txt_api(service, key)
|
197 |
-
response = api.generate_text(model, parameters)
|
198 |
st.session_state.txt2txt_running = False
|
199 |
|
200 |
st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
+
from lib import Config, ServicePresets, txt2txt_generate
|
7 |
|
8 |
HF_TOKEN = os.environ.get("HF_TOKEN") or None
|
9 |
PPLX_API_KEY = os.environ.get("PPLX_API_KEY") or None
|
10 |
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
# config
|
13 |
st.set_page_config(
|
14 |
page_title=f"{Config.TITLE} | Text Generation",
|
|
|
17 |
)
|
18 |
|
19 |
# initialize state
|
20 |
+
if "api_key_huggingface" not in st.session_state:
|
21 |
+
st.session_state.api_key_huggingface = ""
|
22 |
+
|
23 |
+
if "api_key_perplexity" not in st.session_state:
|
24 |
+
st.session_state.api_key_perplexity = ""
|
25 |
|
26 |
if "txt2txt_messages" not in st.session_state:
|
27 |
st.session_state.txt2txt_messages = []
|
|
|
29 |
if "txt2txt_prompt" not in st.session_state:
|
30 |
st.session_state.txt2txt_prompt = ""
|
31 |
|
32 |
+
if "txt2txt_running" not in st.session_state:
|
33 |
+
st.session_state.txt2txt_running = False
|
34 |
+
|
35 |
# sidebar
|
36 |
st.logo("logo.svg")
|
37 |
st.sidebar.header("Settings")
|
|
|
42 |
disabled=st.session_state.txt2txt_running,
|
43 |
)
|
44 |
|
45 |
+
if service == "Huggingface" and HF_TOKEN is None:
|
46 |
+
st.session_state.api_key_huggingface = st.sidebar.text_input(
|
|
|
47 |
"API Key",
|
|
|
48 |
type="password",
|
49 |
help="Cleared on page refresh",
|
50 |
disabled=st.session_state.txt2txt_running,
|
51 |
+
value=st.session_state.api_key_huggingface,
|
52 |
)
|
53 |
+
else:
|
54 |
+
st.session_state.api_key_huggingface = None
|
55 |
+
|
56 |
+
if service == "Perplexity" and PPLX_API_KEY is None:
|
57 |
+
st.session_state.api_key_perplexity = st.sidebar.text_input(
|
58 |
+
"API Key",
|
59 |
+
type="password",
|
60 |
+
help="Cleared on page refresh",
|
61 |
+
disabled=st.session_state.txt2txt_running,
|
62 |
+
value=st.session_state.api_key_perplexity,
|
63 |
+
)
|
64 |
+
else:
|
65 |
+
st.session_state.api_key_perplexity = None
|
66 |
+
|
67 |
+
if service == "Huggingface" and HF_TOKEN is not None:
|
68 |
+
st.session_state.api_key_huggingface = HF_TOKEN
|
69 |
+
|
70 |
+
if service == "Perplexity" and PPLX_API_KEY is not None:
|
71 |
+
st.session_state.api_key_perplexity = PPLX_API_KEY
|
72 |
|
73 |
model = st.sidebar.selectbox(
|
74 |
"Model",
|
|
|
|
|
75 |
options=Config.TXT2TXT_MODELS[service],
|
76 |
+
index=Config.TXT2TXT_DEFAULT_MODEL[service],
|
77 |
disabled=st.session_state.txt2txt_running,
|
78 |
+
format_func=lambda x: x.split("/")[1] if service == "Huggingface" else x,
|
79 |
)
|
80 |
system = st.sidebar.text_area(
|
81 |
"System Message",
|
|
|
85 |
|
86 |
# build parameters from preset
|
87 |
parameters = {}
|
88 |
+
preset = getattr(ServicePresets, service, {})
|
89 |
for param in preset["parameters"]:
|
90 |
if param == "max_tokens":
|
91 |
parameters[param] = st.sidebar.slider(
|
92 |
"Max Tokens",
|
93 |
+
step=128,
|
94 |
+
value=512,
|
95 |
min_value=512,
|
96 |
max_value=4096,
|
|
|
|
|
|
|
97 |
disabled=st.session_state.txt2txt_running,
|
98 |
+
help="Maximum number of tokens to generate (default: 512)",
|
99 |
)
|
100 |
if param == "temperature":
|
101 |
parameters[param] = st.sidebar.slider(
|
102 |
"Temperature",
|
103 |
+
step=0.1,
|
104 |
+
value=1.0,
|
105 |
min_value=0.0,
|
106 |
max_value=2.0,
|
|
|
|
|
|
|
107 |
disabled=st.session_state.txt2txt_running,
|
108 |
+
help="Used to modulate the next token probabilities (default: 1.0)",
|
109 |
)
|
110 |
if param == "frequency_penalty":
|
111 |
parameters[param] = st.sidebar.slider(
|
112 |
"Frequency Penalty",
|
113 |
+
step=0.1,
|
114 |
+
value=preset["frequency_penalty"],
|
115 |
min_value=preset["frequency_penalty_min"],
|
116 |
max_value=preset["frequency_penalty_max"],
|
|
|
|
|
|
|
117 |
disabled=st.session_state.txt2txt_running,
|
118 |
+
help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
|
119 |
)
|
120 |
if param == "seed":
|
121 |
parameters[param] = st.sidebar.number_input(
|
122 |
"Seed",
|
123 |
+
value=-1,
|
124 |
min_value=-1,
|
125 |
max_value=(1 << 53) - 1,
|
|
|
|
|
126 |
disabled=st.session_state.txt2txt_running,
|
127 |
+
help="Make a best effort to sample deterministically (default: -1)",
|
128 |
)
|
129 |
|
|
|
|
|
|
|
|
|
130 |
# heading
|
131 |
st.html("""
|
132 |
<h1 style="padding: 0; margin-bottom: 0.5rem">Text Generation</h1>
|
|
|
184 |
):
|
185 |
st.session_state.txt2txt_prompt = prompt
|
186 |
|
187 |
+
if parameters.get("seed", 0) < 0:
|
188 |
+
parameters["seed"] = int(datetime.now().timestamp() * 1e6) % (1 << 53)
|
|
|
189 |
|
190 |
if button_container:
|
191 |
button_container.empty()
|
|
|
195 |
messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
|
196 |
parameters["messages"] = messages
|
197 |
|
198 |
+
with st.chat_message("user"):
|
199 |
+
st.markdown(st.session_state.txt2txt_prompt)
|
200 |
+
|
201 |
with st.chat_message("assistant"):
|
202 |
+
api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
|
203 |
+
response = txt2txt_generate(api_key, service, model, parameters)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
st.session_state.txt2txt_running = False
|
205 |
|
206 |
st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
|
pages/2_🎨_Text_to_Image.py
CHANGED
@@ -3,11 +3,10 @@ from datetime import datetime
|
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
-
from lib import Config,
|
7 |
|
8 |
-
# TODO: key input and store in cache_data
|
9 |
-
# TODO: API dropdown; changes available models
|
10 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
11 |
API_URL = "https://api-inference.huggingface.co/models"
|
12 |
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
|
13 |
PRESET_MODEL = {
|
@@ -16,12 +15,6 @@ PRESET_MODEL = {
|
|
16 |
"stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
|
17 |
}
|
18 |
|
19 |
-
|
20 |
-
@st.cache_resource
|
21 |
-
def get_txt2img_api():
|
22 |
-
return HuggingFaceTxt2ImgAPI(HF_TOKEN)
|
23 |
-
|
24 |
-
|
25 |
# config
|
26 |
st.set_page_config(
|
27 |
page_title=f"{Config.TITLE} | Text to Image",
|
@@ -30,24 +23,65 @@ st.set_page_config(
|
|
30 |
)
|
31 |
|
32 |
# initialize state
|
33 |
-
if "
|
34 |
-
st.session_state.
|
|
|
|
|
|
|
35 |
|
36 |
if "txt2img_messages" not in st.session_state:
|
37 |
st.session_state.txt2img_messages = []
|
38 |
|
|
|
|
|
|
|
39 |
if "txt2img_seed" not in st.session_state:
|
40 |
st.session_state.txt2img_seed = 0
|
41 |
|
42 |
# sidebar
|
43 |
st.logo("logo.svg")
|
44 |
st.sidebar.header("Settings")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
model = st.sidebar.selectbox(
|
46 |
"Model",
|
47 |
-
format_func=lambda x: x.split("/")[1],
|
48 |
options=Config.TXT2IMG_MODELS,
|
49 |
index=Config.TXT2IMG_DEFAULT_MODEL,
|
50 |
disabled=st.session_state.txt2img_running,
|
|
|
51 |
)
|
52 |
aspect_ratio = st.sidebar.select_slider(
|
53 |
"Aspect Ratio",
|
@@ -191,11 +225,10 @@ if prompt := st.chat_input(
|
|
191 |
|
192 |
with st.chat_message("assistant"):
|
193 |
with st.spinner("Running..."):
|
194 |
-
generate_kwargs = {"model": model, "prompt": prompt, "parameters": parameters}
|
195 |
if preset.get("kwargs") is not None:
|
196 |
-
|
197 |
-
|
198 |
-
image =
|
199 |
st.session_state.txt2img_running = False
|
200 |
|
201 |
model_name = PRESET_MODEL[model]["name"]
|
|
|
3 |
|
4 |
import streamlit as st
|
5 |
|
6 |
+
from lib import Config, ModelPresets, txt2img_generate
|
7 |
|
|
|
|
|
8 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
9 |
+
FAL_KEY = os.environ.get("FAL_KEY")
|
10 |
API_URL = "https://api-inference.huggingface.co/models"
|
11 |
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
|
12 |
PRESET_MODEL = {
|
|
|
15 |
"stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
|
16 |
}
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
# config
|
19 |
st.set_page_config(
|
20 |
page_title=f"{Config.TITLE} | Text to Image",
|
|
|
23 |
)
|
24 |
|
25 |
# initialize state
|
26 |
+
if "api_key_fal" not in st.session_state:
|
27 |
+
st.session_state.api_key_fal = ""
|
28 |
+
|
29 |
+
if "api_key_huggingface" not in st.session_state:
|
30 |
+
st.session_state.api_key_huggingface = ""
|
31 |
|
32 |
if "txt2img_messages" not in st.session_state:
|
33 |
st.session_state.txt2img_messages = []
|
34 |
|
35 |
+
if "txt2img_running" not in st.session_state:
|
36 |
+
st.session_state.txt2img_running = False
|
37 |
+
|
38 |
if "txt2img_seed" not in st.session_state:
|
39 |
st.session_state.txt2img_seed = 0
|
40 |
|
41 |
# sidebar
|
42 |
st.logo("logo.svg")
|
43 |
st.sidebar.header("Settings")
|
44 |
+
service = st.sidebar.selectbox(
|
45 |
+
"Service",
|
46 |
+
options=["Huggingface"],
|
47 |
+
index=0,
|
48 |
+
disabled=st.session_state.txt2img_running,
|
49 |
+
)
|
50 |
+
|
51 |
+
if service == "Huggingface" and HF_TOKEN is None:
|
52 |
+
st.session_state.api_key_huggingface = st.sidebar.text_input(
|
53 |
+
"API Key",
|
54 |
+
type="password",
|
55 |
+
help="Cleared on page refresh",
|
56 |
+
value=st.session_state.api_key_huggingface,
|
57 |
+
disabled=st.session_state.txt2txt_running,
|
58 |
+
)
|
59 |
+
else:
|
60 |
+
st.session_state.api_key_huggingface = None
|
61 |
+
|
62 |
+
if service == "Fal" and FAL_KEY is None:
|
63 |
+
st.session_state.api_key_fal = st.sidebar.text_input(
|
64 |
+
"API Key",
|
65 |
+
type="password",
|
66 |
+
help="Cleared on page refresh",
|
67 |
+
value=st.session_state.api_key_fal,
|
68 |
+
disabled=st.session_state.txt2txt_running,
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
st.session_state.api_key_fal = None
|
72 |
+
|
73 |
+
if service == "Huggingface" and HF_TOKEN is not None:
|
74 |
+
st.session_state.api_key_huggingface = HF_TOKEN
|
75 |
+
|
76 |
+
if service == "Fal" and FAL_KEY is not None:
|
77 |
+
st.session_state.api_key_fal = FAL_KEY
|
78 |
+
|
79 |
model = st.sidebar.selectbox(
|
80 |
"Model",
|
|
|
81 |
options=Config.TXT2IMG_MODELS,
|
82 |
index=Config.TXT2IMG_DEFAULT_MODEL,
|
83 |
disabled=st.session_state.txt2img_running,
|
84 |
+
format_func=lambda x: x.split("/")[1],
|
85 |
)
|
86 |
aspect_ratio = st.sidebar.select_slider(
|
87 |
"Aspect Ratio",
|
|
|
225 |
|
226 |
with st.chat_message("assistant"):
|
227 |
with st.spinner("Running..."):
|
|
|
228 |
if preset.get("kwargs") is not None:
|
229 |
+
parameters.update(preset["kwargs"])
|
230 |
+
api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
|
231 |
+
image = txt2img_generate(api_key, service, model, prompt, parameters)
|
232 |
st.session_state.txt2img_running = False
|
233 |
|
234 |
model_name = PRESET_MODEL[model]["name"]
|