Update app.py
app.py (CHANGED)
@@ -15,67 +15,109 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 # Constants
 CONTEXT_SIZES = {
+    "4K": 4096,
+    "8K": 8192,
+    "32K": 32768,
+    "64K": 65536,
+    "128K": 131072
 }
 
 MODEL_CONTEXT_SIZES = {
+    "Clipboard only": 4096,
+    "OpenAI ChatGPT": {
+        "gpt-3.5-turbo": 4096,
+        "gpt-4": 8192,
+        "gpt-4-32k": 32768
+    },
+    "HuggingFace Inference": {
+        "microsoft/phi-3-mini-4k-instruct": 4096,
+        "HuggingFaceH4/zephyr-7b-beta": 8192,
+        "deepseek-ai/DeepSeek-Coder-V2-Instruct": 8192,
+        "meta-llama/Llama-3-8b-Instruct": 8192,
+        "mistralai/Mistral-7B-Instruct-v0.3": 32768,
+        "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768
+    },
+    "Groq API": {
+        "gemma-7b-it": 8192,
+        "llama-3.1-70b": 32768,
+        "mixtral-8x7b-32768": 32768,
+        "llama-3.1-8b": 8192
+    }
 }
 
 class ModelRegistry:
+    def __init__(self):
+        # HuggingFace Models
+        self.hf_models = {
+            "Phi-3 Mini 4K": "microsoft/phi-3-mini-4k-instruct",
+            "Phi-3 Mini 128k": "microsoft/Phi-3-mini-128k-instruct",
+            "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
+            "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
+            "Meta Llama 3.1 8B": "meta-llama/Llama-3-8b-Instruct",
+            "Meta Llama 3.1 70B": "meta-llama/Meta-Llama-3.1-70B-Instruct",
+            "Mixtral 7B": "mistralai/Mistral-7B-Instruct-v0.3",
+            "Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
+            "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
+            "Aya 23-35B": "CohereForAI/aya-23-35B",
+            "Custom Model": ""
+        }
+
+        # Default Groq Models
+        self.default_groq_models = {
+            "gemma-7b-it": "gemma-7b-it",
+            "llama-3.1-70b-8192": "llama-3.1-70b-8192",
+            "llama-3.1-70b-versatile": "llama-3.1-70b-versatile",
+            "mixtral-8x7b-32768": "mixtral-8x7b-32768",
+            "llama-3.1-8b-instant": "llama-3.1-8b-instant",
+            "llama-3.1-70b-8192-tool-use-preview": "llama3-groq-70b-8192-tool-use-preview"
+        }
+
+        self.groq_models = self._fetch_groq_models()
+
+    def _fetch_groq_models(self) -> Dict[str, str]:
+        """Fetch available Groq models with proper error handling"""
+        try:
+            groq_api_key = os.getenv('GROQ_API_KEY')
+            if not groq_api_key:
+                logging.warning("No GROQ_API_KEY found in environment")
+                return self.default_groq_models
+
+            headers = {
+                "Authorization": f"Bearer {groq_api_key}",
+                "Content-Type": "application/json"
+            }
+
+            response = requests.get(
+                "https://api.groq.com/openai/v1/models",
+                headers=headers,
+                timeout=10
+            )
+
+            if response.status_code == 200:
+                models = response.json().get("data", [])
+                model_dict = {model["id"]: model["id"] for model in models}
+
+                # Merge with defaults to ensure all models are available
+                return {**self.default_groq_models, **model_dict}
+            else:
+                logging.error(f"Failed to fetch Groq models: {response.status_code}")
+                return self.default_groq_models
+
+        except requests.exceptions.Timeout:
+            logging.error("Timeout while fetching Groq models")
+            return self.default_groq_models
+        except Exception as e:
+            logging.error(f"Error fetching Groq models: {e}")
+            return self.default_groq_models
+
+    def _get_default_groq_models(self) -> Dict[str, str]:
+        """Return default Groq models"""
+        return self.default_groq_models
+
+    def refresh_groq_models(self) -> Dict[str, str]:
+        """Refresh the list of available Groq models"""
+        self.groq_models = self._fetch_groq_models()
+        return self.groq_models
 
 # Initialize model registry
 model_registry = ModelRegistry()
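
Every failure path in _fetch_groq_models falls back to default_groq_models instead of raising, so the registry can be constructed without credentials or network access. A minimal sketch of that behavior, assuming the class above and its imports (os, logging, requests, Dict) are in scope:

import os

os.environ.pop("GROQ_API_KEY", None)        # simulate a missing key
registry = ModelRegistry()                  # logs a warning, no exception
assert registry.groq_models == registry.default_groq_models

os.environ["GROQ_API_KEY"] = "gsk_..."      # hypothetical placeholder, not a real key
models = registry.refresh_groq_models()     # with a valid key: defaults merged with the live list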
@@ -208,69 +250,58 @@ def send_to_model_impl(prompt, model_selection, hf_model_choice, hf_custom_model
     return error_msg, []
 
 def send_to_hf_inference(prompt: str, model_name: str, api_key: str) -> str:
-        return f"Error with HF inference: {e}"
+    try:
+        client = InferenceClient(token=api_key)
+        response = client.text_generation(
+            prompt,
+            model=model_name,
+            max_new_tokens=500,
+            temperature=0.7,
+            top_p=0.95,
+            repetition_penalty=1.1
+        )
+        return str(response)
+    except Exception as e:
+        logging.error(f"Error with HF inference: {e}")
+        return f"Error with HF inference: {e}"
 
 def send_to_groq(prompt: str, model_name: str, api_key: str) -> str:
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": prompt}],
-            temperature=0.7,
-            max_tokens=500
-        )
-
-        return response.choices[0].message.content
-    except Exception as e:
-        logging.error(f"Error with OpenAI API: {e}")
-        return f"Error with OpenAI API: {e}"
+    try:
+        client = Groq(api_key=api_key)
+        response = client.chat.completions.create(
+            model=model_name,
+            messages=[{
+                "role": "user",
+                "content": prompt
+            }],
+            temperature=0.7,
+            max_tokens=500,
+            top_p=0.95
+        )
+        return response.choices[0].message.content
+    except Exception as e:
+        logging.error(f"Error with Groq API: {e}")
+        return f"Error with Groq API: {e}"
+
+def send_to_openai(prompt: str, api_key: str, model: str = "gpt-3.5-turbo") -> str:
+    try:
+        import openai
+        openai.api_key = api_key
+
+        response = openai.ChatCompletion.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant that provides detailed responses with examples and references where appropriate."},
+                {"role": "user", "content": prompt}
+            ],
+            temperature=0.7,
+            max_tokens=500,
+            top_p=0.95
+        )
+        return response.choices[0].message.content
+    except Exception as e:
+        logging.error(f"Error with OpenAI API: {e}")
+        return f"Error with OpenAI API: {e}"
 
 def copy_text_js(element_id: str) -> str:
     return f"""function() {{
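
All three senders share one contract: they return either the completion text or an "Error with ..." string rather than raising, so callers can bind the result straight to a Textbox. Note that send_to_openai targets the pre-1.0 openai package; on openai>=1.0 the equivalent call is OpenAI(api_key=api_key).chat.completions.create(...). A hypothetical dispatcher over that shared contract, keyed on the same provider names as MODEL_CONTEXT_SIZES:

def send_to_provider(provider: str, prompt: str, api_key: str, model_name: str) -> str:
    # Hypothetical helper, not part of this diff: routes on the provider
    # labels used by MODEL_CONTEXT_SIZES; always returns a string.
    if provider == "HuggingFace Inference":
        return send_to_hf_inference(prompt, model_name, api_key)
    if provider == "Groq API":
        return send_to_groq(prompt, model_name, api_key)
    if provider == "OpenAI ChatGPT":
        return send_to_openai(prompt, api_key, model_name)
    return f"Error: unknown provider '{provider}'"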
@@ -460,50 +491,51 @@ with gr.Blocks(css="""
 
         # Tab 3: Model Processing
         with gr.Tab("3️⃣ Model Processing"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    model_choice = gr.Radio(
+                        choices=list(MODEL_CONTEXT_SIZES.keys()),
+                        value="Clipboard only",
+                        label="🤖 Provider Selection"
+                    )
+
+                    with gr.Column(visible=False) as openai_options:
+                        openai_model = gr.Dropdown(
+                            choices=list(MODEL_CONTEXT_SIZES["OpenAI ChatGPT"].keys()),
+                            value="gpt-3.5-turbo",
+                            label="OpenAI Model"
+                        )
+                        openai_api_key = gr.Textbox(
+                            label="🔑 OpenAI API Key",
+                            type="password"
+                        )
+
+                    with gr.Column(visible=False) as hf_options:
+                        hf_model = gr.Dropdown(
+                            choices=list(MODEL_CONTEXT_SIZES["HuggingFace Inference"].keys()),
+                            value="microsoft/phi-3-mini-4k-instruct",
+                            label="HuggingFace Model"
+                        )
+                        hf_api_key = gr.Textbox(
+                            label="🔑 HuggingFace API Key",
+                            type="password"
+                        )
+
+                    with gr.Column(visible=False) as groq_options:
+                        groq_model = gr.Dropdown(
+                            choices=list(MODEL_CONTEXT_SIZES["Groq API"].keys()),
+                            value="mixtral-8x7b-32768",
+                            label="Groq Model"
+                        )
+                        groq_api_key = gr.Textbox(
+                            label="🔑 Groq API Key",
+                            type="password"
+                        )
 
-                        choices=list(model_registry.groq_models.keys()),
-                        label="🔧 Groq Model"
-                    )
-                    groq_refresh_btn = gr.Button("🔄 Refresh Models")
-                    groq_api_key = gr.Textbox(
-                        label="🔑 Groq API Key",
-                        type="password"
-                    )
+                    send_to_model_btn = gr.Button("🚀 Send to Model", variant="primary")
+                    open_chatgpt_button = gr.Button("🌐 Open ChatGPT")
 
-                    open_chatgpt_button = gr.Button("🌐 Open ChatGPT")
-
-                with gr.Column(scale=1):
+                with gr.Column(scale=1):
                     summary_output = gr.Textbox(
                         label="📝 Summary",
                         lines=15,
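
The three provider columns are created with visible=False and are meant to be revealed by handle_model_selection (defined in a later hunk). The usual Gradio wiring for that pattern, sketched here assuming a context_size component defined elsewhere in the app, binds one output per gr.update the handler returns, in the same order:

model_choice.change(
    fn=handle_model_selection,
    inputs=[model_choice],
    outputs=[hf_options, groq_options, openai_options, context_size]
)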
@@ -569,26 +601,31 @@ with gr.Blocks(css="""
 
     def toggle_custom_model(model_name):
         return gr.update(visible=model_name == "Custom Model")
-
-    def handle_model_change(choice):
-        """Handle model selection change"""
-        return (
-            gr.update(visible=choice == "HuggingFace Inference"),
-            gr.update(visible=choice == "Groq API"),
-            gr.update(visible=choice == "OpenAI ChatGPT"),
-            update_context_size(choice)
-        )
 
     def handle_groq_model_change(model_name):
         """Handle Groq model selection change"""
         return update_context_size("Groq API", model_name)
 
     def handle_model_selection(choice):
+        """Handle model selection and update UI"""
+        ctx_size = MODEL_CONTEXT_SIZES.get(choice, {})
+        if isinstance(ctx_size, dict):
+            first_model = list(ctx_size.keys())[0]
+            ctx_size = ctx_size[first_model]
+
+        # Update model dropdown based on provider
+        if choice == "OpenAI ChatGPT":
+            openai_model.update(choices=list(MODEL_CONTEXT_SIZES["OpenAI ChatGPT"].keys()))
+        elif choice == "HuggingFace Inference":
+            hf_model.update(choices=list(MODEL_CONTEXT_SIZES["HuggingFace Inference"].keys()))
+        elif choice == "Groq API":
+            groq_model.update(choices=list(MODEL_CONTEXT_SIZES["Groq API"].keys()))
+
         return [
             gr.update(visible=choice == "HuggingFace Inference"),
             gr.update(visible=choice == "Groq API"),
             gr.update(visible=choice == "OpenAI ChatGPT"),
-            gr.update(value=
+            gr.update(value=ctx_size)
         ]
 
     # PDF Processing Handlers
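
MODEL_CONTEXT_SIZES mixes flat entries with nested per-model dicts, which is why the handler collapses a dict to its first model's size before building the updates:

ctx_size = MODEL_CONTEXT_SIZES.get("OpenAI ChatGPT", {})  # {"gpt-3.5-turbo": 4096, ...}
if isinstance(ctx_size, dict):
    ctx_size = ctx_size[list(ctx_size.keys())[0]]         # 4096, the first listed model
# "Clipboard only" maps straight to the int 4096 and skips the branch

One caveat: the bare openai_model.update(...) calls inside the handler do not themselves change the UI; in Gradio an update only takes effect when it is returned to a matching output component, so those dropdown refreshes would also need to appear in the returned list.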
@@ -766,23 +803,29 @@ with gr.Blocks(css="""
     )
 
     # Download handlers
-    for btn,
-        (
-        (
-            (download_prompt, generated_prompt, "prompt"),
-            (download_summary, summary_output, "summary")
+    for btn, elem_id in [
+        (copy_prompt_button, "generated_prompt"),
+        (copy_summary_button, "summary_output")
     ]:
         btn.click(
+            fn=None,
+            _js=f"""
+                () => {{
+                    const el = document.getElementById('{elem_id}');
+                    if (!el) return 'Element not found';
+                    navigator.clipboard.writeText(el.value);
+                    return 'Copied to clipboard!';
+                }}
+            """,
+            outputs=progress_status
         )
 
     def download_file(content: str, prefix: str) -> List[str]:
         if not content:
             return []
         try:
+            filename = f"{prefix}_{int(time.time())}.txt"  # Add timestamp
+            with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.txt', prefix=filename) as f:
                 f.write(content)
                 return [f.name]
         except Exception as e:
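
Two details here are worth flagging. The _js keyword matches Gradio 3.x (Gradio 4 renamed it to js), and because filename already ends in ".txt", tempfile inserts its random component between prefix and suffix, doubling the extension. A quick standard-library sketch of the resulting name:

import tempfile
import time

prefix = f"summary_{int(time.time())}.txt"
with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.txt', prefix=prefix) as f:
    f.write("example content")
print(f.name)  # e.g. /tmp/summary_1700000000.txtab12cd.txt

Dropping ".txt" from filename (or dropping the suffix argument) would avoid the doubled extension.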
|