Spaces:
Running
Running
yash
commited on
Commit
·
54da33c
1
Parent(s):
0ed5782
first commit
Browse files- app.py +144 -0
- requirements.txt +81 -0
app.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
# Seed for reproducibility. Greedy decoding (below) is already
# deterministic, but seeding keeps any future sampled paths repeatable.
torch.random.manual_seed(0)

# Load the model and tokenizer from Hugging Face.
# NOTE(review): device_map="cpu" keeps this Space CPU-only; a 3.8B-parameter
# model will be slow — switch to "cuda" on a GPU Space for usable latency.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-mini-instruct",
    device_map="cpu",
    torch_dtype="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")

# Text-generation pipeline wrapping the loaded model/tokenizer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Generation settings: greedy decoding (do_sample=False), at most 500 new
# tokens, and only the newly generated text returned (not the full prompt).
# The original also passed "temperature": 0.0, which is ignored (and warns)
# when do_sample=False, so it is dropped here.
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,
    "do_sample": False,
}

# In-memory session store: {session_id: [[user_msg, assistant_msg], ...]}.
# NOTE(review): module-level state — lost on restart and shared by all users.
chat_session = {}
|
36 |
+
|
37 |
+
|
38 |
+
# Build the chat-template message list from the history and query the model.
def generate_response(chat_history):
    """Generate an assistant reply for the latest user turn.

    Args:
        chat_history: list of [user_message, assistant_message] pairs; the
            LAST pair holds the new user message with a placeholder (e.g.
            "") in the assistant slot.

    Returns:
        The assistant's reply text for the last user message.

    Raises:
        ValueError: if `chat_history` is empty.
    """
    if not chat_history:
        raise ValueError("chat_history must contain at least the pending user turn")

    messages = [{"role": "system", "content": "You are a helpful AI assistant."}]

    # Replay only the COMPLETED turns. The original code iterated the whole
    # history (including the pending pair) and then appended the latest user
    # message again, so the prompt contained that message twice plus a
    # spurious empty assistant turn.
    for user_message, assistant_message in chat_history[:-1]:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": assistant_message})

    # The pending turn contributes its user message exactly once.
    messages.append({"role": "user", "content": chat_history[-1][0]})

    response = pipe(messages, **generation_args)
    return response[0]["generated_text"]
|
55 |
+
|
56 |
+
|
57 |
+
# Handle one chat turn: validate the session, generate a reply, persist it.
def chat(user_message, history, session):
    """Append a user turn, generate the assistant reply, store the history.

    Args:
        user_message: text the user just submitted.
        history: list of [user, assistant] pairs, or None/[] on first use.
        session: current session id; must be non-empty.

    Returns:
        (updated_history, input_box_text) — the second value is "" on
        success (clearing the textbox), or an error prompt when no session
        is active.
    """
    # `not session` also rejects None, which `session == ""` missed. The
    # error text is shown in the input textbox by the Gradio wiring.
    if not session:
        return history, "Error: Session ID cannot be empty. Please start a new chat."

    history = history or []  # First turn: Gradio may pass None.

    # The pending pair carries the new message with an empty placeholder.
    assistant_message = generate_response(history + [(user_message, "")])

    history.append([user_message, assistant_message])
    chat_session[session] = history  # Persist so the session can be reloaded.

    return history, ""
|
74 |
+
|
75 |
+
|
76 |
+
def get_session_list():
    """Return the ids of all stored chat sessions, in insertion order."""
    return [*chat_session]
|
78 |
+
|
79 |
+
# Create a new chat session and refresh the session dropdown.
def new_chat():
    """Start a fresh session: empty history, new id, refreshed dropdown.

    Returns:
        (history, input_text, session_id, dropdown_update) — matching the
        outputs [chatbot, user_input, session, session_list].
    """
    # NOTE(review): ids could collide if individual sessions were ever
    # deleted; harmless while reset clears the whole dict at once.
    session = f'session:{len(chat_session) + 1}'
    chat_session[session] = []  # Empty history for the new session.
    # gr.update(choices=...) refreshes the dropdown's OPTION LIST; the
    # original returned a bare list, which Gradio interprets as the
    # dropdown's value rather than its choices.
    return [], "", session, gr.update(choices=get_session_list())
|
84 |
+
|
85 |
+
# Fetch the stored history of a previously created session.
def old_chat(sessions):
    """Return the chat history for `sessions`, or [] if the id is unknown."""
    try:
        return chat_session[sessions]
    except KeyError:
        return []
|
88 |
+
|
89 |
+
# Wipe every stored session and blank out all UI components.
def reset_button():
    """Reset the app: drop all sessions and clear the UI.

    Returns:
        Values for [chatbot, session, user_input, session_list,
        load_session] in that order.
    """
    global chat_session  # Rebinding the module-level store requires global.
    chat_session = {}
    # gr.update empties the dropdown's choices as well as its value; the
    # original returned a bare [], which only sets the value.
    return [], "", "", gr.update(choices=[], value=None), ""
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks(css=".small-btn {width: 100px !important;} .large-textbox {width: 100% !important;}") as demo:
    gr.Markdown("# 🤖 AI Assistant")

    with gr.Column():
        new_chat_button = gr.Button("Start New Chat")
        with gr.Row():
            # Left column: the conversation and the message-entry row.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(elem_id="chatbot")
                with gr.Row():
                    with gr.Column(scale=5):
                        user_input = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message here...",
                            container=False,
                            elem_classes="large-textbox"
                        )
                    with gr.Column(scale=1):
                        send_button = gr.Button("Send", variant="primary", elem_classes="small-btn")

            # Right column: session-management controls.
            with gr.Column(scale=1):
                session = gr.Textbox(label="Current Session", interactive=False)
                session_list = gr.Dropdown(label="Available Sessions", choices=get_session_list(), allow_custom_value=True)
                load_session = gr.Textbox(label="Load Session", interactive=True)
                with gr.Row():
                    get_old_session_button = gr.Button("Load Session")
                    avail_session = gr.Button("Get Available Session")
                reset_button_ = gr.Button("Reset All", variant="secondary")

    # Event wiring: Enter key and Send button both submit a chat turn.
    user_input.submit(chat, [user_input, chatbot, session], [chatbot, user_input])
    send_button.click(chat, [user_input, chatbot, session], [chatbot, user_input])
    new_chat_button.click(new_chat, [], [chatbot, user_input, session, session_list])
    # Loads by the id typed into the "Load Session" textbox; the dropdown is
    # display-only here — NOTE(review): wiring it as the input may be nicer UX.
    get_old_session_button.click(old_chat, [load_session], [chatbot])
    reset_button_.click(reset_button, [], [chatbot, session, user_input, session_list, load_session])
    # Wrap in gr.update(choices=...) so the dropdown's OPTION LIST refreshes;
    # the original handed Gradio a bare list, which sets the dropdown value.
    avail_session.click(lambda: gr.update(choices=get_session_list()), [], [session_list])

# Launch the Gradio app
demo.launch()
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.26.0
|
2 |
+
aiofiles==23.2.1
|
3 |
+
annotated-types==0.7.0
|
4 |
+
anyio==4.6.0
|
5 |
+
certifi==2024.8.30
|
6 |
+
charset-normalizer==3.3.2
|
7 |
+
click==8.1.7
|
8 |
+
contourpy==1.3.0
|
9 |
+
cycler==0.12.1
|
10 |
+
exceptiongroup==1.2.2
|
11 |
+
fastapi==0.115.0
|
12 |
+
ffmpy==0.4.0
|
13 |
+
filelock==3.16.1
|
14 |
+
fonttools==4.54.1
|
15 |
+
fsspec==2024.9.0
|
16 |
+
gradio==4.44.1
|
17 |
+
gradio_client==1.3.0
|
18 |
+
h11==0.14.0
|
19 |
+
httpcore==1.0.6
|
20 |
+
httpx==0.27.2
|
21 |
+
huggingface-hub==0.25.1
|
22 |
+
idna==3.10
|
23 |
+
importlib_resources==6.4.5
|
24 |
+
Jinja2==3.1.4
|
25 |
+
kiwisolver==1.4.7
|
26 |
+
markdown-it-py==3.0.0
|
27 |
+
MarkupSafe==2.1.5
|
28 |
+
matplotlib==3.9.2
|
29 |
+
mdurl==0.1.2
|
30 |
+
mpmath==1.3.0
|
31 |
+
networkx==3.3
|
32 |
+
numpy==2.1.2
|
33 |
+
nvidia-cublas-cu12==12.1.3.1
|
34 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
35 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
36 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
37 |
+
nvidia-cudnn-cu12==9.1.0.70
|
38 |
+
nvidia-cufft-cu12==11.0.2.54
|
39 |
+
nvidia-curand-cu12==10.3.2.106
|
40 |
+
nvidia-cusolver-cu12==11.4.5.107
|
41 |
+
nvidia-cusparse-cu12==12.1.0.106
|
42 |
+
nvidia-nccl-cu12==2.20.5
|
43 |
+
nvidia-nvjitlink-cu12==12.6.77
|
44 |
+
nvidia-nvtx-cu12==12.1.105
|
45 |
+
orjson==3.10.7
|
46 |
+
packaging==24.1
|
47 |
+
pandas==2.2.3
|
48 |
+
pillow==10.4.0
|
49 |
+
psutil==6.0.0
|
50 |
+
pydantic==2.9.2
|
51 |
+
pydantic_core==2.23.4
|
52 |
+
pydub==0.25.1
|
53 |
+
Pygments==2.18.0
|
54 |
+
pyparsing==3.1.4
|
55 |
+
python-dateutil==2.9.0.post0
|
56 |
+
python-multipart==0.0.12
|
57 |
+
pytz==2024.2
|
58 |
+
PyYAML==6.0.2
|
59 |
+
regex==2024.9.11
|
60 |
+
requests==2.32.3
|
61 |
+
rich==13.9.2
|
62 |
+
ruff==0.6.9
|
63 |
+
safetensors==0.4.5
|
64 |
+
semantic-version==2.10.0
|
65 |
+
shellingham==1.5.4
|
66 |
+
six==1.16.0
|
67 |
+
sniffio==1.3.1
|
68 |
+
starlette==0.38.6
|
69 |
+
sympy==1.13.3
|
70 |
+
tokenizers==0.20.0
|
71 |
+
tomlkit==0.12.0
|
72 |
+
torch==2.4.1
|
73 |
+
tqdm==4.66.5
|
74 |
+
transformers==4.45.2
|
75 |
+
triton==3.0.0
|
76 |
+
typer==0.12.5
|
77 |
+
typing_extensions==4.12.2
|
78 |
+
tzdata==2024.2
|
79 |
+
urllib3==2.2.3
|
80 |
+
uvicorn==0.31.0
|
81 |
+
websockets==12.0
|