adamelliotfields commited on
Commit
7f1bd15
·
verified ·
1 Parent(s): 04549f6

Client improvements

Browse files
lib/__init__.py CHANGED
@@ -1,12 +1,11 @@
1
- from .api import HuggingFaceTxt2ImgAPI, HuggingFaceTxt2TxtAPI, PerplexityTxt2TxtAPI
2
  from .config import Config
3
  from .presets import ModelPresets, ServicePresets
4
 
5
  __all__ = [
6
  "Config",
7
- "HuggingFaceTxt2ImgAPI",
8
- "HuggingFaceTxt2TxtAPI",
9
  "ModelPresets",
10
- "PerplexityTxt2TxtAPI",
11
  "ServicePresets",
 
 
12
  ]
 
1
+ from .api import txt2img_generate, txt2txt_generate
2
  from .config import Config
3
  from .presets import ModelPresets, ServicePresets
4
 
5
  __all__ = [
6
  "Config",
 
 
7
  "ModelPresets",
 
8
  "ServicePresets",
9
+ "txt2img_generate",
10
+ "txt2txt_generate",
11
  ]
lib/api.py CHANGED
@@ -1,87 +1,48 @@
1
  import io
2
- from abc import ABC, abstractmethod
3
 
4
  import requests
5
  import streamlit as st
6
  from openai import APIError, OpenAI
7
  from PIL import Image
8
 
9
-
10
- class Txt2TxtAPI(ABC):
11
- @abstractmethod
12
- def generate_text(self, model, parameters, **kwargs):
13
- pass
14
-
15
-
16
- class Txt2ImgAPI(ABC):
17
- @abstractmethod
18
- def generate_image(self, model, prompt, parameters, **kwargs):
19
- pass
20
-
21
-
22
- class HuggingFaceTxt2TxtAPI(Txt2TxtAPI):
23
- def __init__(self, api_key):
24
- self.api_key = api_key
25
-
26
- def generate_text(self, model, parameters, **kwargs):
27
- if not self.api_key:
28
- return "API Key is required."
29
- client = OpenAI(
30
- api_key=self.api_key,
31
- base_url=f"https://api-inference.huggingface.co/models/{model}/v1",
 
 
 
 
 
 
 
 
 
 
 
32
  )
33
- try:
34
- stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
35
- return st.write_stream(stream)
36
- except APIError as e:
37
- return e.message
38
- except Exception as e:
39
- return str(e)
40
-
41
-
42
- class PerplexityTxt2TxtAPI(Txt2TxtAPI):
43
- def __init__(self, api_key):
44
- self.api_key = api_key
45
-
46
- def generate_text(self, model, parameters, **kwargs):
47
- if not self.api_key:
48
- return "API Key is required."
49
- client = OpenAI(
50
- api_key=self.api_key,
51
- base_url="https://api.perplexity.ai",
52
- )
53
- try:
54
- stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
55
- return st.write_stream(stream)
56
- except APIError as e:
57
- return e.message
58
- except Exception as e:
59
- return str(e)
60
-
61
-
62
- # essentially the same as huggingface_hub's inference client
63
- class HuggingFaceTxt2ImgAPI(Txt2ImgAPI):
64
- def __init__(self, token):
65
- self.api_url = "https://api-inference.huggingface.co/models"
66
- self.headers = {
67
- "Authorization": f"Bearer {token}",
68
- "X-Wait-For-Model": "true",
69
- "X-Use-Cache": "false",
70
- }
71
-
72
- def generate_image(self, model, prompt, parameters, **kwargs):
73
- try:
74
- response = requests.post(
75
- f"{self.api_url}/{model}",
76
- headers=self.headers,
77
- json={
78
- "inputs": prompt,
79
- "parameters": {**parameters, **kwargs},
80
- },
81
- )
82
- if response.status_code == 200:
83
- return Image.open(io.BytesIO(response.content))
84
- else:
85
- raise Exception(f"Error: {response.status_code} - {response.text}")
86
- except Exception as e:
87
- return str(e)
 
1
  import io
 
2
 
3
  import requests
4
  import streamlit as st
5
  from openai import APIError, OpenAI
6
  from PIL import Image
7
 
8
+ from .config import Config
9
+
10
+
11
+ def txt2txt_generate(api_key, service, model, parameters, **kwargs):
12
+ base_url = Config.SERVICES[service]
13
+ if service == "Huggingface":
14
+ base_url = f"{base_url}/{model}/v1"
15
+ client = OpenAI(api_key=api_key, base_url=base_url)
16
+
17
+ try:
18
+ stream = client.chat.completions.create(stream=True, model=model, **parameters, **kwargs)
19
+ return st.write_stream(stream)
20
+ except APIError as e:
21
+ return e.message
22
+ except Exception as e:
23
+ return str(e)
24
+
25
+
26
+ def txt2img_generate(api_key, service, model, inputs, parameters, **kwargs):
27
+ headers = {
28
+ "Authorization": f"Bearer {api_key}",
29
+ "X-Wait-For-Model": "true",
30
+ "X-Use-Cache": "false",
31
+ }
32
+ base_url = f"{Config.SERVICES[service]}/{model}"
33
+
34
+ try:
35
+ response = requests.post(
36
+ base_url,
37
+ headers=headers,
38
+ json={
39
+ "inputs": inputs,
40
+ "parameters": {**parameters, **kwargs},
41
+ },
42
  )
43
+ if response.status_code == 200:
44
+ return Image.open(io.BytesIO(response.content))
45
+ else:
46
+ return f"Error: {response.status_code} - {response.text}"
47
+ except Exception as e:
48
+ return str(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/config.py CHANGED
@@ -4,6 +4,10 @@ Config = SimpleNamespace(
4
  TITLE="API Inference",
5
  ICON="⚡",
6
  LAYOUT="wide",
 
 
 
 
7
  TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
8
  TXT2IMG_DEFAULT_MODEL=2,
9
  TXT2IMG_MODELS=[
 
4
  TITLE="API Inference",
5
  ICON="⚡",
6
  LAYOUT="wide",
7
+ SERVICES={
8
+ "Huggingface": "https://api-inference.huggingface.co/models",
9
+ "Perplexity": "https://api.perplexity.ai",
10
+ },
11
  TXT2IMG_NEGATIVE_PROMPT="ugly, unattractive, malformed, mutated, disgusting, blurry, grainy, noisy, oversaturated, undersaturated, overexposed, underexposed, worst quality, low details, lowres, watermark, signature, autograph, trademark, sloppy, cluttered",
12
  TXT2IMG_DEFAULT_MODEL=2,
13
  TXT2IMG_MODELS=[
lib/presets.py CHANGED
@@ -1,5 +1,6 @@
1
  from types import SimpleNamespace
2
 
 
3
  ServicePresets = SimpleNamespace(
4
  Huggingface={
5
  # every service has model and system messages
@@ -16,6 +17,7 @@ ServicePresets = SimpleNamespace(
16
  },
17
  )
18
 
 
19
  ModelPresets = SimpleNamespace(
20
  FLUX_1_DEV={
21
  "name": "FLUX.1 Dev",
 
1
  from types import SimpleNamespace
2
 
3
+ # txt2txt services
4
  ServicePresets = SimpleNamespace(
5
  Huggingface={
6
  # every service has model and system messages
 
17
  },
18
  )
19
 
20
+ # txt2img models
21
  ModelPresets = SimpleNamespace(
22
  FLUX_1_DEV={
23
  "name": "FLUX.1 Dev",
pages/1_💬_Text_Generation.py CHANGED
@@ -3,21 +3,12 @@ from datetime import datetime
3
 
4
  import streamlit as st
5
 
6
- from lib import Config, HuggingFaceTxt2TxtAPI, PerplexityTxt2TxtAPI, ServicePresets
7
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN") or None
9
  PPLX_API_KEY = os.environ.get("PPLX_API_KEY") or None
10
 
11
 
12
- @st.cache_resource
13
- def get_txt2txt_api(service="Huggingface", api_key=None):
14
- if service == "Huggingface":
15
- return HuggingFaceTxt2TxtAPI(api_key)
16
- if service == "Perplexity":
17
- return PerplexityTxt2TxtAPI(api_key)
18
- return None
19
-
20
-
21
  # config
22
  st.set_page_config(
23
  page_title=f"{Config.TITLE} | Text Generation",
@@ -26,8 +17,11 @@ st.set_page_config(
26
  )
27
 
28
  # initialize state
29
- if "txt2txt_running" not in st.session_state:
30
- st.session_state.txt2txt_running = False
 
 
 
31
 
32
  if "txt2txt_messages" not in st.session_state:
33
  st.session_state.txt2txt_messages = []
@@ -35,6 +29,9 @@ if "txt2txt_messages" not in st.session_state:
35
  if "txt2txt_prompt" not in st.session_state:
36
  st.session_state.txt2txt_prompt = ""
37
 
 
 
 
38
  # sidebar
39
  st.logo("logo.svg")
40
  st.sidebar.header("Settings")
@@ -45,22 +42,40 @@ service = st.sidebar.selectbox(
45
  disabled=st.session_state.txt2txt_running,
46
  )
47
 
48
- # hide key input if environment variables are set
49
- if (service == "Huggingface" and HF_TOKEN is None) or (service == "Perplexity" and PPLX_API_KEY is None):
50
- api_key = st.sidebar.text_input(
51
  "API Key",
52
- value="",
53
  type="password",
54
  help="Cleared on page refresh",
55
  disabled=st.session_state.txt2txt_running,
 
56
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  model = st.sidebar.selectbox(
59
  "Model",
60
- format_func=lambda x: x.split("/")[1] if service == "Huggingface" else x,
61
- index=Config.TXT2TXT_DEFAULT_MODEL[service],
62
  options=Config.TXT2TXT_MODELS[service],
 
63
  disabled=st.session_state.txt2txt_running,
 
64
  )
65
  system = st.sidebar.text_area(
66
  "System Message",
@@ -70,52 +85,48 @@ system = st.sidebar.text_area(
70
 
71
  # build parameters from preset
72
  parameters = {}
73
- preset = getattr(ServicePresets, service)
74
  for param in preset["parameters"]:
75
  if param == "max_tokens":
76
  parameters[param] = st.sidebar.slider(
77
  "Max Tokens",
 
 
78
  min_value=512,
79
  max_value=4096,
80
- value=512,
81
- step=128,
82
- help="Maximum number of tokens to generate (default: 512)",
83
  disabled=st.session_state.txt2txt_running,
 
84
  )
85
  if param == "temperature":
86
  parameters[param] = st.sidebar.slider(
87
  "Temperature",
 
 
88
  min_value=0.0,
89
  max_value=2.0,
90
- value=1.0,
91
- step=0.1,
92
- help="Used to modulate the next token probabilities (default: 1.0)",
93
  disabled=st.session_state.txt2txt_running,
 
94
  )
95
  if param == "frequency_penalty":
96
  parameters[param] = st.sidebar.slider(
97
  "Frequency Penalty",
 
 
98
  min_value=preset["frequency_penalty_min"],
99
  max_value=preset["frequency_penalty_max"],
100
- value=preset["frequency_penalty"],
101
- step=0.1,
102
- help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
103
  disabled=st.session_state.txt2txt_running,
 
104
  )
105
  if param == "seed":
106
  parameters[param] = st.sidebar.number_input(
107
  "Seed",
 
108
  min_value=-1,
109
  max_value=(1 << 53) - 1,
110
- value=-1,
111
- help="Make a best effort to sample deterministically (default: -1)",
112
  disabled=st.session_state.txt2txt_running,
 
113
  )
114
 
115
- # random seed
116
- if parameters.get("seed", 0) < 0:
117
- parameters["seed"] = int(datetime.now().timestamp() * 1e6) % (1 << 53)
118
-
119
  # heading
120
  st.html("""
121
  <h1 style="padding: 0; margin-bottom: 0.5rem">Text Generation</h1>
@@ -173,9 +184,8 @@ if prompt := st.chat_input(
173
  ):
174
  st.session_state.txt2txt_prompt = prompt
175
 
176
- if st.session_state.txt2txt_prompt:
177
- with st.chat_message("user"):
178
- st.markdown(st.session_state.txt2txt_prompt)
179
 
180
  if button_container:
181
  button_container.empty()
@@ -185,16 +195,12 @@ if st.session_state.txt2txt_prompt:
185
  messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
186
  parameters["messages"] = messages
187
 
 
 
 
188
  with st.chat_message("assistant"):
189
- # allow environment variables in development for convenience
190
- if service == "Huggingface" and HF_TOKEN is not None:
191
- key = HF_TOKEN
192
- elif service == "Perplexity" and PPLX_API_KEY is not None:
193
- key = PPLX_API_KEY
194
- else:
195
- key = api_key
196
- api = get_txt2txt_api(service, key)
197
- response = api.generate_text(model, parameters)
198
  st.session_state.txt2txt_running = False
199
 
200
  st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
 
3
 
4
  import streamlit as st
5
 
6
+ from lib import Config, ServicePresets, txt2txt_generate
7
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN") or None
9
  PPLX_API_KEY = os.environ.get("PPLX_API_KEY") or None
10
 
11
 
 
 
 
 
 
 
 
 
 
12
  # config
13
  st.set_page_config(
14
  page_title=f"{Config.TITLE} | Text Generation",
 
17
  )
18
 
19
  # initialize state
20
+ if "api_key_huggingface" not in st.session_state:
21
+ st.session_state.api_key_huggingface = ""
22
+
23
+ if "api_key_perplexity" not in st.session_state:
24
+ st.session_state.api_key_perplexity = ""
25
 
26
  if "txt2txt_messages" not in st.session_state:
27
  st.session_state.txt2txt_messages = []
 
29
  if "txt2txt_prompt" not in st.session_state:
30
  st.session_state.txt2txt_prompt = ""
31
 
32
+ if "txt2txt_running" not in st.session_state:
33
+ st.session_state.txt2txt_running = False
34
+
35
  # sidebar
36
  st.logo("logo.svg")
37
  st.sidebar.header("Settings")
 
42
  disabled=st.session_state.txt2txt_running,
43
  )
44
 
45
+ if service == "Huggingface" and HF_TOKEN is None:
46
+ st.session_state.api_key_huggingface = st.sidebar.text_input(
 
47
  "API Key",
 
48
  type="password",
49
  help="Cleared on page refresh",
50
  disabled=st.session_state.txt2txt_running,
51
+ value=st.session_state.api_key_huggingface,
52
  )
53
+ else:
54
+ st.session_state.api_key_huggingface = None
55
+
56
+ if service == "Perplexity" and PPLX_API_KEY is None:
57
+ st.session_state.api_key_perplexity = st.sidebar.text_input(
58
+ "API Key",
59
+ type="password",
60
+ help="Cleared on page refresh",
61
+ disabled=st.session_state.txt2txt_running,
62
+ value=st.session_state.api_key_perplexity,
63
+ )
64
+ else:
65
+ st.session_state.api_key_perplexity = None
66
+
67
+ if service == "Huggingface" and HF_TOKEN is not None:
68
+ st.session_state.api_key_huggingface = HF_TOKEN
69
+
70
+ if service == "Perplexity" and PPLX_API_KEY is not None:
71
+ st.session_state.api_key_perplexity = PPLX_API_KEY
72
 
73
  model = st.sidebar.selectbox(
74
  "Model",
 
 
75
  options=Config.TXT2TXT_MODELS[service],
76
+ index=Config.TXT2TXT_DEFAULT_MODEL[service],
77
  disabled=st.session_state.txt2txt_running,
78
+ format_func=lambda x: x.split("/")[1] if service == "Huggingface" else x,
79
  )
80
  system = st.sidebar.text_area(
81
  "System Message",
 
85
 
86
  # build parameters from preset
87
  parameters = {}
88
+ preset = getattr(ServicePresets, service, {})
89
  for param in preset["parameters"]:
90
  if param == "max_tokens":
91
  parameters[param] = st.sidebar.slider(
92
  "Max Tokens",
93
+ step=128,
94
+ value=512,
95
  min_value=512,
96
  max_value=4096,
 
 
 
97
  disabled=st.session_state.txt2txt_running,
98
+ help="Maximum number of tokens to generate (default: 512)",
99
  )
100
  if param == "temperature":
101
  parameters[param] = st.sidebar.slider(
102
  "Temperature",
103
+ step=0.1,
104
+ value=1.0,
105
  min_value=0.0,
106
  max_value=2.0,
 
 
 
107
  disabled=st.session_state.txt2txt_running,
108
+ help="Used to modulate the next token probabilities (default: 1.0)",
109
  )
110
  if param == "frequency_penalty":
111
  parameters[param] = st.sidebar.slider(
112
  "Frequency Penalty",
113
+ step=0.1,
114
+ value=preset["frequency_penalty"],
115
  min_value=preset["frequency_penalty_min"],
116
  max_value=preset["frequency_penalty_max"],
 
 
 
117
  disabled=st.session_state.txt2txt_running,
118
+ help="Penalize new tokens based on their existing frequency in the text (default: 0.0)",
119
  )
120
  if param == "seed":
121
  parameters[param] = st.sidebar.number_input(
122
  "Seed",
123
+ value=-1,
124
  min_value=-1,
125
  max_value=(1 << 53) - 1,
 
 
126
  disabled=st.session_state.txt2txt_running,
127
+ help="Make a best effort to sample deterministically (default: -1)",
128
  )
129
 
 
 
 
 
130
  # heading
131
  st.html("""
132
  <h1 style="padding: 0; margin-bottom: 0.5rem">Text Generation</h1>
 
184
  ):
185
  st.session_state.txt2txt_prompt = prompt
186
 
187
+ if parameters.get("seed", 0) < 0:
188
+ parameters["seed"] = int(datetime.now().timestamp() * 1e6) % (1 << 53)
 
189
 
190
  if button_container:
191
  button_container.empty()
 
195
  messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
196
  parameters["messages"] = messages
197
 
198
+ with st.chat_message("user"):
199
+ st.markdown(st.session_state.txt2txt_prompt)
200
+
201
  with st.chat_message("assistant"):
202
+ api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
203
+ response = txt2txt_generate(api_key, service, model, parameters)
 
 
 
 
 
 
 
204
  st.session_state.txt2txt_running = False
205
 
206
  st.session_state.txt2txt_messages.append({"role": "user", "content": st.session_state.txt2txt_prompt})
pages/2_🎨_Text_to_Image.py CHANGED
@@ -3,11 +3,10 @@ from datetime import datetime
3
 
4
  import streamlit as st
5
 
6
- from lib import Config, HuggingFaceTxt2ImgAPI, ModelPresets
7
 
8
- # TODO: key input and store in cache_data
9
- # TODO: API dropdown; changes available models
10
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
11
  API_URL = "https://api-inference.huggingface.co/models"
12
  HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
13
  PRESET_MODEL = {
@@ -16,12 +15,6 @@ PRESET_MODEL = {
16
  "stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
17
  }
18
 
19
-
20
- @st.cache_resource
21
- def get_txt2img_api():
22
- return HuggingFaceTxt2ImgAPI(HF_TOKEN)
23
-
24
-
25
  # config
26
  st.set_page_config(
27
  page_title=f"{Config.TITLE} | Text to Image",
@@ -30,24 +23,65 @@ st.set_page_config(
30
  )
31
 
32
  # initialize state
33
- if "txt2img_running" not in st.session_state:
34
- st.session_state.txt2img_running = False
 
 
 
35
 
36
  if "txt2img_messages" not in st.session_state:
37
  st.session_state.txt2img_messages = []
38
 
 
 
 
39
  if "txt2img_seed" not in st.session_state:
40
  st.session_state.txt2img_seed = 0
41
 
42
  # sidebar
43
  st.logo("logo.svg")
44
  st.sidebar.header("Settings")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  model = st.sidebar.selectbox(
46
  "Model",
47
- format_func=lambda x: x.split("/")[1],
48
  options=Config.TXT2IMG_MODELS,
49
  index=Config.TXT2IMG_DEFAULT_MODEL,
50
  disabled=st.session_state.txt2img_running,
 
51
  )
52
  aspect_ratio = st.sidebar.select_slider(
53
  "Aspect Ratio",
@@ -191,11 +225,10 @@ if prompt := st.chat_input(
191
 
192
  with st.chat_message("assistant"):
193
  with st.spinner("Running..."):
194
- generate_kwargs = {"model": model, "prompt": prompt, "parameters": parameters}
195
  if preset.get("kwargs") is not None:
196
- generate_kwargs.update(preset["kwargs"])
197
- api = get_txt2img_api()
198
- image = api.generate_image(**generate_kwargs)
199
  st.session_state.txt2img_running = False
200
 
201
  model_name = PRESET_MODEL[model]["name"]
 
3
 
4
  import streamlit as st
5
 
6
+ from lib import Config, ModelPresets, txt2img_generate
7
 
 
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN")
9
+ FAL_KEY = os.environ.get("FAL_KEY")
10
  API_URL = "https://api-inference.huggingface.co/models"
11
  HEADERS = {"Authorization": f"Bearer {HF_TOKEN}", "X-Wait-For-Model": "true", "X-Use-Cache": "false"}
12
  PRESET_MODEL = {
 
15
  "stabilityai/stable-diffusion-xl-base-1.0": ModelPresets.STABLE_DIFFUSION_XL,
16
  }
17
 
 
 
 
 
 
 
18
  # config
19
  st.set_page_config(
20
  page_title=f"{Config.TITLE} | Text to Image",
 
23
  )
24
 
25
  # initialize state
26
+ if "api_key_fal" not in st.session_state:
27
+ st.session_state.api_key_fal = ""
28
+
29
+ if "api_key_huggingface" not in st.session_state:
30
+ st.session_state.api_key_huggingface = ""
31
 
32
  if "txt2img_messages" not in st.session_state:
33
  st.session_state.txt2img_messages = []
34
 
35
+ if "txt2img_running" not in st.session_state:
36
+ st.session_state.txt2img_running = False
37
+
38
  if "txt2img_seed" not in st.session_state:
39
  st.session_state.txt2img_seed = 0
40
 
41
  # sidebar
42
  st.logo("logo.svg")
43
  st.sidebar.header("Settings")
44
+ service = st.sidebar.selectbox(
45
+ "Service",
46
+ options=["Huggingface"],
47
+ index=0,
48
+ disabled=st.session_state.txt2img_running,
49
+ )
50
+
51
+ if service == "Huggingface" and HF_TOKEN is None:
52
+ st.session_state.api_key_huggingface = st.sidebar.text_input(
53
+ "API Key",
54
+ type="password",
55
+ help="Cleared on page refresh",
56
+ value=st.session_state.api_key_huggingface,
57
+ disabled=st.session_state.txt2txt_running,
58
+ )
59
+ else:
60
+ st.session_state.api_key_huggingface = None
61
+
62
+ if service == "Fal" and FAL_KEY is None:
63
+ st.session_state.api_key_fal = st.sidebar.text_input(
64
+ "API Key",
65
+ type="password",
66
+ help="Cleared on page refresh",
67
+ value=st.session_state.api_key_fal,
68
+ disabled=st.session_state.txt2txt_running,
69
+ )
70
+ else:
71
+ st.session_state.api_key_fal = None
72
+
73
+ if service == "Huggingface" and HF_TOKEN is not None:
74
+ st.session_state.api_key_huggingface = HF_TOKEN
75
+
76
+ if service == "Fal" and FAL_KEY is not None:
77
+ st.session_state.api_key_fal = FAL_KEY
78
+
79
  model = st.sidebar.selectbox(
80
  "Model",
 
81
  options=Config.TXT2IMG_MODELS,
82
  index=Config.TXT2IMG_DEFAULT_MODEL,
83
  disabled=st.session_state.txt2img_running,
84
+ format_func=lambda x: x.split("/")[1],
85
  )
86
  aspect_ratio = st.sidebar.select_slider(
87
  "Aspect Ratio",
 
225
 
226
  with st.chat_message("assistant"):
227
  with st.spinner("Running..."):
 
228
  if preset.get("kwargs") is not None:
229
+ parameters.update(preset["kwargs"])
230
+ api_key = getattr(st.session_state, f"api_key_{service.lower()}", None)
231
+ image = txt2img_generate(api_key, service, model, prompt, parameters)
232
  st.session_state.txt2img_running = False
233
 
234
  model_name = PRESET_MODEL[model]["name"]