John6666 committed (verified)
Commit f788018
1 Parent(s): 52f0345

Upload 7 files
Files changed (7)
  1. README.md +16 -12
  2. app.py +69 -0
  3. chatllm.py +301 -0
  4. model.py +19 -0
  5. prompt.py +5 -0
  6. requirements.txt +8 -0
  7. utils.py +37 -0
README.md CHANGED
@@ -1,12 +1,16 @@
- ---
- title: Llm Multi Demo
- emoji: 📚
- colorFrom: red
- colorTo: gray
- sdk: gradio
- sdk_version: 4.41.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: 20+ Multi LLM Playground with Web Search
+ emoji: 💻🧲
+ colorFrom: indigo
+ colorTo: red
+ sdk: gradio
+ sdk_version: 4.40.0
+ app_file: app.py
+ pinned: false
+ duplicated_from:
+ - prithivMLmods/WEB-DAC
+ - featherless-ai/try-this-model
+ license: creativeml-openrail-m
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,69 @@
+ import gradio as gr
+ from chatllm import (
+     chat_response, get_llm_model, set_llm_model, get_llm_model_info,
+     get_llm_language, set_llm_language, get_llm_sysprompt, get_llm_sysprompt_mode,
+     set_llm_sysprompt_mode,
+ )
+
+ # Custom CSS for Gradio app
+ css = '''
+ .gradio-container{max-width: 1000px !important}
+ h1{text-align:center}
+ footer { visibility: hidden }
+ '''
+
+ # Create Gradio interface
+ with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", css=css) as app:
+     with gr.Column():
+         with gr.Group():
+             chatbot = gr.Chatbot(likeable=False, show_copy_button=True, show_share_button=False, layout="bubble", container=True)
+             with gr.Row():
+                 chat_query = gr.Textbox(label="Search Query", placeholder="hatsune miku", value="", scale=3)
+                 chat_clear = gr.Button("🗑️ Clear", scale=1)
+             with gr.Row():
+                 chat_msg = gr.Textbox(label="Message", placeholder="Input message with or without query and press Enter or click Submit.", value="", scale=3)
+                 chat_submit = gr.Button("Submit", scale=1)
+             with gr.Accordion("Additional inputs", open=False):
+                 chat_model = gr.Dropdown(choices=get_llm_model(), value=get_llm_model()[0], allow_custom_value=True, label="Model")
+                 chat_model_info = gr.Markdown(value=get_llm_model_info(get_llm_model()[0]), label="Model info")
+                 with gr.Row():
+                     chat_mode = gr.Dropdown(choices=get_llm_sysprompt_mode(), value=get_llm_sysprompt_mode()[0], allow_custom_value=False, label="Mode")
+                     chat_lang = gr.Dropdown(choices=get_llm_language(), value="language same as user input", allow_custom_value=True, label="Output language")
+                 chat_tokens = gr.Slider(minimum=1, maximum=4096, value=2000, step=1, label="Max tokens")
+                 chat_temp = gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature")
+                 chat_topp = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
+                 chat_fp = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Frequency penalty")
+                 chat_sysmsg = gr.Textbox(value=get_llm_sysprompt(), interactive=True, label="System message")
+         examples = gr.Examples(
+             examples = [
+                 ["Describe this person.", "Kafuu Chino from Gochiusa"],
+                 ["Hello", ""],
+             ],
+             inputs=[chat_msg, chat_query],
+         )
+         gr.Markdown(
+             f"""This demo was created in reference to the following demos.<br>
+ [prithivMLmods/WEB-DAC](https://huggingface.co/spaces/prithivMLmods/WEB-DAC),
+ [featherless-ai/try-this-model](https://huggingface.co/spaces/featherless-ai/try-this-model),
+ """
+         )
+         gr.DuplicateButton(value="Duplicate Space")
+         gr.Markdown(f"Just a few edits to *model.py* are all it takes to complete your own collection.")
+     gr.on(
+         triggers=[chat_msg.submit, chat_query.submit, chat_submit.click],
+         fn=chat_response,
+         inputs=[chat_msg, chatbot, chat_query, chat_tokens, chat_temp, chat_topp, chat_fp],
+         outputs=[chatbot],
+         queue=True,
+         show_progress="full",
+         trigger_mode="once",
+     )
+     chat_clear.click(lambda: (None, None, None), None, [chatbot, chat_msg, chat_query], queue=False)
+     chat_model.change(set_llm_model, [chat_model], [chat_model, chat_model_info], queue=True, show_progress="full")\
+         .success(lambda: None, None, chatbot, queue=False)
+     chat_mode.change(set_llm_sysprompt_mode, [chat_mode], [chat_sysmsg], queue=False)
+     chat_lang.change(set_llm_language, [chat_lang], [chat_sysmsg], queue=False)
+
+ if __name__ == "__main__":
+     app.queue()
+     app.launch()
chatllm.py ADDED
@@ -0,0 +1,301 @@
+ from huggingface_hub import InferenceClient
+ import json
+ from bs4 import BeautifulSoup
+ import requests
+ import gradio as gr
+
+ from model import llm_models, llm_serverless_models
+ from prompt import llm_system_prompt
+ llm_clients = {}
+ client_main = None
+ current_model = None
+ language_codes = {"English": "en", "Japanese": "ja", "Chinese": "zh"}
+ llm_languages = ["language same as user input"] + list(language_codes.keys())
+ llm_output_language = "language same as user input"
+ llm_sysprompt_mode = "Default"
+ server_timeout = 300
+
+ def get_llm_sysprompt():
+     import re
+     prompt = re.sub('<LANGUAGE>', llm_output_language, llm_system_prompt.get(llm_sysprompt_mode, ""))
+     return prompt
+
+ def get_llm_sysprompt_mode():
+     return list(llm_system_prompt.keys())
+
+ def set_llm_sysprompt_mode(key: str):
+     global llm_sysprompt_mode
+     if key not in llm_system_prompt.keys():
+         llm_sysprompt_mode = "Default"
+     else:
+         llm_sysprompt_mode = key
+     return gr.update(value=get_llm_sysprompt())
+
+ def get_llm_language():
+     return llm_languages
+
+ def set_llm_language(lang: str):
+     global llm_output_language
+     llm_output_language = lang
+     return gr.update(value=get_llm_sysprompt())
+
+ def get_llm_model_info(model_name):
+     return f'Repo: [{model_name}](https://huggingface.co/{model_name})'
+
+ # Function to extract visible text from a webpage
+ def get_text_from_html(html_content):
+     soup = BeautifulSoup(html_content, 'html.parser')
+     for tag in soup(["script", "style", "header", "footer"]):
+         tag.extract()
+     return soup.get_text(strip=True)
+
+ # Function to determine the output language code for the web search
+ def get_language_code(s):
+     from langdetect import detect
+     lang = "en"
+     if llm_output_language == "language same as user input":
+         lang = detect(s)
+     elif llm_output_language in language_codes.keys():
+         lang = language_codes[llm_output_language]
+     return lang
+
+ def perform_search(query):
+     import urllib3
+     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+     search_term = query
+     lang = get_language_code(search_term)
+     all_results = []
+     max_chars_per_page = 8000
+     with requests.Session() as session:
+         response = session.get(
+             url="https://www.google.com/search",
+             headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.0.0"},
+             params={"q": search_term, "num": 3, "udm": 14, "hl": f"{lang}", "lr": f"lang_{lang}", "safe": "off", "pws": 0},
+             timeout=5,
+             verify=False,
+         )
+         response.raise_for_status()
+         soup = BeautifulSoup(response.text, "html.parser")
+         result_block = soup.find_all("div", attrs={"class": "g"})
+         for result in result_block:
+             link = result.find("a", href=True)["href"]
+             try:
+                 webpage_response = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.0.0"}, timeout=5, verify=False)
+                 webpage_response.raise_for_status()
+                 visible_text = get_text_from_html(webpage_response.text)
+                 if len(visible_text) > max_chars_per_page:
+                     visible_text = visible_text[:max_chars_per_page]
+                 all_results.append({"link": link, "text": visible_text})
+             except requests.exceptions.RequestException:
+                 all_results.append({"link": link, "text": None})
+     return all_results
+
+ # https://github.com/gradio-app/gradio/blob/main/gradio/external.py
+ # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
+ def load_from_model(model_name: str, hf_token: str = None):
+     import httpx
+     import huggingface_hub
+     from gradio.exceptions import ModelNotFoundError
+     model_url = f"https://huggingface.co/{model_name}"
+     api_url = f"https://api-inference.huggingface.co/models/{model_name}"
+     print(f"Fetching model from: {model_url}")
+
+     headers = {"Authorization": f"Bearer {hf_token}"} if hf_token is not None else {}
+     response = httpx.request("GET", api_url, headers=headers)
+     if response.status_code != 200:
+         raise ModelNotFoundError(
+             f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `hf_token` parameter."
+         )
+     headers["X-Wait-For-Model"] = "true"
+     client = huggingface_hub.InferenceClient(model=model_name, headers=headers,
+                                              token=hf_token, timeout=server_timeout)
+     inputs = [
+         gr.components.Textbox(render=False),
+         gr.components.State(render=False),
+     ]
+     outputs = [
+         gr.components.Chatbot(render=False),
+         gr.components.State(render=False),
+     ]
+     fn = client.chat_completion
+
+     def query_huggingface_inference_endpoints(*data, **kwargs):
+         return fn(*data, **kwargs)
+
+     interface_info = {
+         "fn": query_huggingface_inference_endpoints,
+         "inputs": inputs,
+         "outputs": outputs,
+         "title": model_name,
+     }
+     return gr.Interface(**interface_info)
+
+ def get_status(model_name: str):
+     client = InferenceClient(timeout=10)
+     return client.get_model_status(model_name)
+
+ def load_clients():
+     global llm_clients
+     for model in llm_serverless_models:
+         status = get_status(model)
+         #print(f"HF model status: {status}")
+         if status is None or status.state not in ["Loadable", "Loaded"]:
+             print(f"Failed to load by serverless inference API: {model}. Model state is {status.state if status else 'None'}")
+             continue
+         try:
+             print(f"Fetching model by serverless inference API: {model}")
+             llm_clients[model] = InferenceClient(model)
+         except Exception as e:
+             print(e)
+             print(f"Failed to load by serverless inference API: {model}")
+             continue
+         print(f"Loaded by serverless inference API: {model}")
+     for model in llm_models:
+         if model in llm_clients.keys(): continue
+         status = get_status(model)
+         #print(f"HF model status: {status}")
+         if status is None or status.state not in ["Loadable", "Loaded"]:
+             print(f"Failed to load: {model}. Model state is {status.state if status else 'None'}")
+             continue
+         try:
+             llm_clients[model] = load_from_model(model)
+         except Exception as e:
+             print(e)
+             print(f"Failed to load: {model}")
+             continue
+         print(f"Loaded: {model}")
+
+ def add_client(model_name: str):
+     global llm_clients
+     try:
+         status = get_status(model_name)
+         #print(f"HF model status: {status}")
+         if status is None or status.state not in ["Loadable", "Loaded"]:
+             print(f"Failed to load: {model_name}. Model state is {status.state if status else 'None'}")
+             new_client = None
+         else: new_client = InferenceClient(model_name)
+     except Exception as e:
+         print(e)
+         new_client = None
+     if new_client:
+         print(f"Loaded by serverless inference API: {model_name}")
+         llm_clients[model_name] = new_client
+         return new_client
+     else:
+         print(f"Failed to load: {model_name}")
+         return llm_clients.get(llm_serverless_models[0], None)
+
+ def set_llm_model(model_name: str = llm_serverless_models[0]):
+     global client_main
+     global current_model
+     if model_name in llm_clients.keys():
+         client_main = llm_clients.get(model_name, None)
+     else:
+         client_main = add_client(model_name)
+     if client_main is not None:
+         current_model = model_name
+         print(f"Model selected: {model_name}")
+         print(f"HF model status: {get_status(model_name)}")
+         return model_name, get_llm_model_info(model_name)
+     else: return None, "None"
+
+ def get_llm_model():
+     return list(llm_clients.keys())
+
+ # Initialize inference clients
+ load_clients()
+ set_llm_model()
+ client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
+
+ # https://huggingface.co/docs/huggingface_hub/v0.24.5/en/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion
+ def chat_body(message, history, query, tokens, temperature, top_p, fpenalty, web_summary):
+     system_prompt = get_llm_sysprompt()
+     if query and web_summary:
+         messages = []
+         messages.append({"role": "system", "content": system_prompt})
+         for msg in history:
+             messages.append({"role": "user", "content": str(msg[0])})
+             messages.append({"role": "assistant", "content": str(msg[1])})
+         messages.append({"role": "user", "content": f"{message}\nweb_result\n{web_summary}"})
+         messages.append({"role": "assistant", "content": ""})
+         try:
+             if isinstance(client_main, gr.Interface):
+                 stream = client_main.fn(messages=messages, max_tokens=tokens, temperature=temperature,
+                                         top_p=top_p, frequency_penalty=fpenalty, stream=True)
+             else:
+                 stream = client_main.chat_completion(messages=messages, max_tokens=tokens, temperature=temperature,
+                                                      top_p=top_p, stream=True)
+         except Exception as e:
+             print(e)
+             stream = []
+         output = ""
+         for response in stream:
+             if response and response.choices and response.choices[0].delta.content is not None:
+                 output += response.choices[0].delta.content
+                 yield [(output, None)]
+     else:
+         messages = []
+         messages.append({"role": "system", "content": system_prompt})
+         for msg in history:
+             messages.append({"role": "user", "content": str(msg[0])})
+             messages.append({"role": "assistant", "content": str(msg[1])})
+         messages.append({"role": "user", "content": message})
+         messages.append({"role": "assistant", "content": ""})
+         try:
+             if isinstance(client_main, gr.Interface):
+                 stream = client_main.fn(messages=messages, max_tokens=tokens, temperature=temperature,
+                                         top_p=top_p, stream=True)
+             else:
+                 stream = client_main.chat_completion(messages=messages, max_tokens=tokens, temperature=temperature,
+                                                      top_p=top_p, stream=True)
+         except Exception as e:
+             print(e)
+             stream = []
+         output = ""
+         for response in stream:
+             if response and response.choices and response.choices[0].delta.content is not None:
+                 output += response.choices[0].delta.content
+                 yield [(output, None)]
+
+ def get_web_summary(history, query_message):
+     if not query_message: return ""
+     func_calls = []
+
+     functions_metadata = [
+         {"type": "function", "function": {"name": "web_search", "description": "Search query on Google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Web search query"}}, "required": ["query"]}}},
+     ]
+
+     for msg in history:
+         func_calls.append({"role": "user", "content": f"{str(msg[0])}"})
+         func_calls.append({"role": "assistant", "content": f"{str(msg[1])}"})
+
+     func_calls.append({"role": "user", "content": f'[SYSTEM] You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_2": "value_2", ... }} }} </functioncall> [USER] {query_message}'})
+
+     response = client_gemma.chat_completion(func_calls, max_tokens=200)
+     response = str(response)
+     try:
+         response = response[int(response.find("{")):int(response.rindex("}"))+1]
+     except:
+         response = response[int(response.find("{")):(int(response.rfind("}"))+1)]
+     response = response.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
+     #print(f"\n{response}")
+
+     try:
+         json_data = json.loads(str(response))
+         if json_data["name"] == "web_search":
+             query = json_data["arguments"]["query"]
+             #gr.Info("Searching Web")
+             web_results = perform_search(query)
+             #gr.Info("Extracting relevant Info")
+             web_summary = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results if res['text']])
+             return web_summary
+         else:
+             return ""
+     except:
+         return ""
+
+ # Function to handle responses
+ def chat_response(message, history, query, tokens, temperature, top_p, fpenalty):
+     if history is None: history = []
+     yield from chat_body(message, history, query, tokens, temperature, top_p, fpenalty, get_web_summary(history, query))
+
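For reference, a minimal sketch of the round trip that `get_web_summary` relies on (the reply string below is made up, not real model output): the model is asked to answer with a `<functioncall>` JSON payload, the outermost braces are sliced out, and the parsed arguments drive `perform_search`.

```python
import json

# Hypothetical model reply; in the app it comes from client_gemma.chat_completion(func_calls, max_tokens=200).
reply = '<functioncall> { "name": "web_search", "arguments": { "query": "hatsune miku" } } </functioncall>'

# Slice out the outermost JSON object, as get_web_summary does with find("{") / rindex("}").
payload = reply[reply.find("{"):reply.rindex("}") + 1]
call = json.loads(payload)

if call["name"] == "web_search":
    # perform_search(call["arguments"]["query"]) would run here; printed instead in this sketch.
    print(call["arguments"]["query"])  # -> hatsune miku
```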
model.py ADDED
@@ -0,0 +1,19 @@
+ from utils import find_model_list, list_uniq
+
+ llm_serverless_models = [
+     "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+ ]
+
+ llm_models = [
+     "mistralai/Mistral-7B-Instruct-v0.3",
+ ]
+
+ #llm_models.extend(find_model_list("Casual-Autopsy"))
+ llm_models.extend(find_model_list("", [], "gguf", "downloads", 60, True))
+ llm_models = list_uniq(llm_models)
+
+ # Examples:
+ #llm_models = ['mistralai/Mistral-7B-Instruct-v0.3', 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO'] # specific models
+ #models = find_model_list("NousResearch", [], "", "last_modified", 20) # NousResearch's latest 20 models
+ #models = find_model_list("", [], "", "last_modified", 20) # the 20 most recently modified text-generation models on Hugging Face
+ #models = find_model_list("", [], "", "downloads", 20) # the 20 most downloaded text-generation models on Hugging Face
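As the note in app.py says, expanding the collection only takes edits to this file. A hypothetical sketch (the first repo id below is a placeholder, not a tested model):

```python
# Hypothetical additions: pin an extra serverless-capable model and pull in an author's recent uploads.
llm_serverless_models.append("your-username/your-chat-model")                    # placeholder repo id
llm_models.extend(find_model_list("NousResearch", [], "", "last_modified", 20))  # same call as the examples above
llm_models = list_uniq(llm_models)                                               # de-duplicate while keeping order
```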
prompt.py ADDED
@@ -0,0 +1,5 @@
+
+ llm_system_prompt = {"Default": r"You are a helpful AI assistant. Respond in <LANGUAGE>.",
+                      "WEB DAC": r"Web Dac uses the user agents of Mozilla, AppleWebKit, and Safari browsers for chat responses and human context mimicking.",
+                      #"Your new prompt": r"You are a helpful AI assistant."
+                      }
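Adding a mode is just another entry in this dict: the key shows up in the app's "Mode" dropdown, and the `<LANGUAGE>` placeholder is substituted by `get_llm_sysprompt()` in chatllm.py. A hypothetical sketch:

```python
# Hypothetical extra mode; the key name is arbitrary.
llm_system_prompt["Concise"] = r"You are a terse AI assistant. Answer in <LANGUAGE> using at most three sentences."
```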
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ huggingface_hub
+ bs4
+ pillow
+ torch
+ git+https://github.com/huggingface/transformers.git
+ opencv-python
+ accelerate
+ langdetect
utils.py ADDED
@@ -0,0 +1,37 @@
+
+ def list_uniq(l):
+     return sorted(set(l), key=l.index)
+
+ def get_status(model_name: str):
+     from huggingface_hub import InferenceClient
+     client = InferenceClient(timeout=10)
+     return client.get_model_status(model_name)
+
+ def is_loadable(model_name: str, force_gpu: bool = False):
+     status = get_status(model_name)
+     gpu_state = status is not None and isinstance(status.compute_type, dict) and "gpu" in status.compute_type.keys()
+     if status is None or status.state not in ["Loadable", "Loaded"] or (force_gpu and not gpu_state):
+         print(f"Couldn't load {model_name}. Model state:'{status.state if status else 'None'}', GPU:{gpu_state}")
+     return status is not None and status.state in ["Loadable", "Loaded"] and (not force_gpu or gpu_state)
+
+ def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=True):
+     from huggingface_hub import HfApi
+     api = HfApi()
+     #default_tags = ["transformers"]
+     default_tags = []
+     if not sort: sort = "last_modified"
+     models = []
+     limit = limit * 20 if force_gpu else limit * 5
+     try:
+         model_infos = api.list_models(author=author, pipeline_tag="text-generation",
+                                       tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit)
+     except Exception as e:
+         print("Error: Failed to list models.")
+         print(e)
+         return models
+     for model in model_infos:
+         if not model.private and not model.gated:
+             if (not_tag and not_tag in model.tags) or not is_loadable(model.id, force_gpu): continue
+             models.append(model.id)
+             if len(models) == limit: break
+     return models