yash committed on
Commit
54da33c
·
1 Parent(s): 0ed5782

first commit

Browse files
Files changed (2) hide show
  1. app.py +144 -0
  2. requirements.txt +81 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr

# Set the random seed for reproducibility
torch.random.manual_seed(0)

# Load the model and tokenizer from Hugging Face.
# NOTE(review): downloads ~7GB of weights on first run; CPU inference of a
# 3.8B-parameter model will be slow — the commented-out "cuda" line is the
# GPU alternative.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3.5-mini-instruct",
    device_map="cpu",
    # device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")

# Create a text-generation pipeline around the loaded model/tokenizer.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

# Keyword arguments forwarded to every pipe(...) call.
# NOTE(review): with do_sample=False decoding is greedy, so temperature is
# effectively ignored (transformers may warn about this combination).
generation_args = {
    "max_new_tokens": 500,
    "return_full_text": False,
    "temperature": 0.0,
    "do_sample": False,
}

# In-memory store of all conversations: session id -> list of
# [user_message, assistant_message] pairs. Lost on process restart and
# shared across all connected users.
chat_session = {}
36
+
37
+
38
# Function to generate responses based on the entire chat history
def generate_response(chat_history):
    """Generate the assistant's reply for the latest user message.

    Parameters
    ----------
    chat_history : list
        Pairs of (user_message, assistant_message). The last pair is the
        pending turn: its user message is the prompt to answer and its
        assistant slot is ignored (the caller passes "").

    Returns
    -------
    str
        The model's generated reply text.
    """
    messages = [{"role": "system", "content": "You are a helpful AI assistant."}]

    # Replay only the *completed* turns. (Bug fix: the previous version also
    # replayed the pending last pair and then appended the same user message
    # a second time, sending a duplicated prompt plus a spurious empty
    # assistant turn to the model.)
    for user_message, assistant_message in chat_history[:-1]:
        messages.append({"role": "user", "content": user_message})
        messages.append({"role": "assistant", "content": assistant_message})

    # The latest user message is the prompt the model should complete.
    messages.append({"role": "user", "content": chat_history[-1][0]})

    response = pipe(messages, **generation_args)
    return response[0]["generated_text"]
55
+
56
+
57
# Handle one chat turn for the current session.
def chat(user_message, history, session):
    """Generate a reply for *user_message* and persist it in the session.

    Returns the updated chatbot history plus the new value for the input
    textbox ("" on success, an error string when no session is active).
    """
    # Refuse to chat until "Start New Chat" has assigned a session id.
    if session == "":
        return history, "Error: Session ID cannot be empty. Please start a new chat."

    history = history or []  # initialize history if empty

    # Ask the model, passing the pending turn (empty assistant slot)
    # appended to the existing history.
    reply = generate_response(history + [(user_message, "")])

    # Record the completed turn and persist it under the session id.
    history.append([user_message, reply])
    chat_session[session] = history

    return history, ""
74
+
75
+
76
def get_session_list():
    """Return the ids of all stored chat sessions."""
    return [*chat_session]
78
+
79
# Start a fresh conversation and return the updated session list.
def new_chat():
    """Create an empty session, make it current, and clear the UI fields."""
    session_id = f"session:{len(chat_session) + 1}"
    chat_session[session_id] = []  # empty history for the new session
    # -> (cleared chatbot, cleared input, new session id, refreshed session list)
    return [], "", session_id, get_session_list()
84
+
85
# Look up a previously stored conversation.
def old_chat(sessions):
    """Return the stored history for *sessions*, or [] if it is unknown."""
    try:
        return chat_session[sessions]
    except KeyError:
        return []
88
+
89
# Wipe every stored session and clear all UI fields.
def reset_button():
    """Drop all sessions and return reset values for the five UI outputs.

    Output order matches the click wiring: chatbot, session, user input,
    session list, load-session textbox.
    """
    global chat_session
    chat_session = {}
    return [], "", "", [], ""
94
+
95
+
96
+
97
# Build the Gradio UI. Layout: a main chat column (chatbot + input row)
# next to a narrower session-management column.
# NOTE(review): nesting reconstructed from the diff — the original
# indentation was stripped; confirm against the deployed app.
with gr.Blocks(css=".small-btn {width: 100px !important;} .large-textbox {width: 100% !important;}") as demo:
    gr.Markdown("# 🤖 AI Assistant")

    with gr.Column():
        new_chat_button = gr.Button("Start New Chat")
        with gr.Row():
            # Left column: the conversation view and the message input.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(elem_id="chatbot")
                with gr.Row():
                    with gr.Column(scale=5):
                        user_input = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message here...",
                            container=False,
                            elem_classes="large-textbox"
                        )
                    with gr.Column(scale=1):
                        send_button = gr.Button("Send", variant="primary", elem_classes="small-btn")

            # Right column: session id display, session picker, and controls.
            with gr.Column(scale=1):
                session = gr.Textbox(label="Current Session", interactive=False)
                # NOTE(review): choices are fixed at build time; the click
                # handlers below push a plain list into this Dropdown, which
                # sets its *value*, not its choices — verify in the UI.
                session_list = gr.Dropdown(label="Available Sessions", choices=get_session_list(), allow_custom_value=True)
                load_session = gr.Textbox(label="Load Session", interactive=True)
                with gr.Row():
                    get_old_session_button = gr.Button("Load Session")
                    avail_session = gr.Button("Get Available Session")
                reset_button_ = gr.Button("Reset All", variant="secondary")

    # Button click actions
    # Enter key and Send button both run chat(); outputs refresh the
    # chatbot and clear (or show an error in) the input box.
    user_input.submit(chat, [user_input, chatbot, session], [chatbot, user_input])
    send_button.click(chat, [user_input, chatbot, session], [chatbot, user_input])  # Send button
    new_chat_button.click(new_chat, [], [chatbot, user_input, session, session_list])  # Also update the session list
    get_old_session_button.click(old_chat, [load_session], [chatbot])
    reset_button_.click(reset_button, [], [chatbot, session, user_input, session_list, load_session])
    avail_session.click(get_session_list, [], [session_list])

# Launch the Gradio app
demo.launch()
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+
requirements.txt ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.26.0
2
+ aiofiles==23.2.1
3
+ annotated-types==0.7.0
4
+ anyio==4.6.0
5
+ certifi==2024.8.30
6
+ charset-normalizer==3.3.2
7
+ click==8.1.7
8
+ contourpy==1.3.0
9
+ cycler==0.12.1
10
+ exceptiongroup==1.2.2
11
+ fastapi==0.115.0
12
+ ffmpy==0.4.0
13
+ filelock==3.16.1
14
+ fonttools==4.54.1
15
+ fsspec==2024.9.0
16
+ gradio==4.44.1
17
+ gradio_client==1.3.0
18
+ h11==0.14.0
19
+ httpcore==1.0.6
20
+ httpx==0.27.2
21
+ huggingface-hub==0.25.1
22
+ idna==3.10
23
+ importlib_resources==6.4.5
24
+ Jinja2==3.1.4
25
+ kiwisolver==1.4.7
26
+ markdown-it-py==3.0.0
27
+ MarkupSafe==2.1.5
28
+ matplotlib==3.9.2
29
+ mdurl==0.1.2
30
+ mpmath==1.3.0
31
+ networkx==3.3
32
+ numpy==2.1.2
33
+ nvidia-cublas-cu12==12.1.3.1
34
+ nvidia-cuda-cupti-cu12==12.1.105
35
+ nvidia-cuda-nvrtc-cu12==12.1.105
36
+ nvidia-cuda-runtime-cu12==12.1.105
37
+ nvidia-cudnn-cu12==9.1.0.70
38
+ nvidia-cufft-cu12==11.0.2.54
39
+ nvidia-curand-cu12==10.3.2.106
40
+ nvidia-cusolver-cu12==11.4.5.107
41
+ nvidia-cusparse-cu12==12.1.0.106
42
+ nvidia-nccl-cu12==2.20.5
43
+ nvidia-nvjitlink-cu12==12.6.77
44
+ nvidia-nvtx-cu12==12.1.105
45
+ orjson==3.10.7
46
+ packaging==24.1
47
+ pandas==2.2.3
48
+ pillow==10.4.0
49
+ psutil==6.0.0
50
+ pydantic==2.9.2
51
+ pydantic_core==2.23.4
52
+ pydub==0.25.1
53
+ Pygments==2.18.0
54
+ pyparsing==3.1.4
55
+ python-dateutil==2.9.0.post0
56
+ python-multipart==0.0.12
57
+ pytz==2024.2
58
+ PyYAML==6.0.2
59
+ regex==2024.9.11
60
+ requests==2.32.3
61
+ rich==13.9.2
62
+ ruff==0.6.9
63
+ safetensors==0.4.5
64
+ semantic-version==2.10.0
65
+ shellingham==1.5.4
66
+ six==1.16.0
67
+ sniffio==1.3.1
68
+ starlette==0.38.6
69
+ sympy==1.13.3
70
+ tokenizers==0.20.0
71
+ tomlkit==0.12.0
72
+ torch==2.4.1
73
+ tqdm==4.66.5
74
+ transformers==4.45.2
75
+ triton==3.0.0
76
+ typer==0.12.5
77
+ typing_extensions==4.12.2
78
+ tzdata==2024.2
79
+ urllib3==2.2.3
80
+ uvicorn==0.31.0
81
+ websockets==12.0