amiguel committed on
Commit
cacd7f6
·
verified ·
1 Parent(s): b6b2322

Upload 3 files

Browse files
Files changed (3) hide show
  1. chat-interface.py +12 -0
  2. model-handler.py +38 -0
  3. streamlit-app.py +55 -0
chat-interface.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
class ChatInterface:
    """Thin wrapper around Streamlit's chat widgets."""

    def __init__(self):
        # Stable widget key so the input box keeps its identity across reruns.
        self.chat_input_key = "chat_input"

    def get_user_input(self):
        """Return the text the user submitted, or None when nothing was entered."""
        return st.chat_input("Type your message here...", key=self.chat_input_key)

    def display_message(self, role, content):
        """Render a single chat bubble for the given role."""
        with st.chat_message(role):
            st.write(content)
model-handler.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+
4
class ModelHandler:
    """Streams text from a causal LM one token at a time."""

    def __init__(self, model, tokenizer):
        # model: a HF-style model exposing .generate(); tokenizer: its tokenizer.
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, conversation, max_tokens=150):
        """Generate a reply to *conversation*, yielding cumulative text.

        This is a generator: every yielded value is the FULL text generated
        so far (not a delta), emitted at most every ~10 ms to throttle UI
        updates. The final state of the text is always yielded before the
        generator finishes.

        Args:
            conversation: plain-text prompt (full chat history).
            max_tokens: upper bound on generated tokens (default 150,
                matching the original hard-coded cap).
        """
        inputs = self.tokenizer(conversation, return_tensors="pt",
                                truncation=True, max_length=1024)
        output = ""
        emitted = ""          # last value actually yielded
        last_emit = time.time()

        with torch.no_grad():
            for _ in range(max_tokens):
                generated = self.model.generate(
                    **inputs,
                    max_new_tokens=1,
                    do_sample=True,
                    top_k=50,
                    top_p=0.95,
                )
                new_token = generated[0, -1].item()

                # BUG FIX: stop BEFORE appending, so the decoded EOS marker
                # text never leaks into the user-visible reply (the original
                # appended it and could even yield it).
                if new_token == self.tokenizer.eos_token_id:
                    break

                output += self.tokenizer.decode([new_token])

                # Re-encode prompt + everything generated so far as the
                # context for the next single-token step.
                inputs = self.tokenizer(conversation + output,
                                        return_tensors="pt",
                                        truncation=True, max_length=1024)

                # Throttled emission: at most one yield per ~10 ms.
                if time.time() - last_emit >= 0.01:
                    emitted = output
                    yield output
                    last_emit = time.time()

        # BUG FIX: flush the tail. The original only yielded inside the timed
        # branch and relied on a generator `return` value (invisible to a
        # plain `for` loop), silently dropping anything generated after the
        # last timed yield.
        if output and output != emitted:
            yield output
streamlit-app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+ from chat_interface import ChatInterface
5
+ from model_handler import ModelHandler
6
+
7
# Page chrome must be configured before any other Streamlit call.
st.set_page_config(page_title="Inspection Engineer Chat", page_icon="🔍")

# Seed the conversation exactly once per session; index 0 holds the hidden
# system prompt that frames every model call.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "system", "content": "You are an experienced senior inspection engineer. Your task is to analyze the scope provided in the input and determine the class item as an output."}
    ]
15
+
16
@st.cache_resource
def load_model():
    """Load the fine-tuned model once per process (cached by Streamlit).

    Returns a ModelHandler wrapping the tokenizer/model pair.
    """
    repo_id = "amiguel/classItem-FT-llama-3-1-8b-instruct"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    return ModelHandler(model, tokenizer)
21
+
22
def main():
    """Drive the chat UI: replay history, accept input, stream a reply."""
    st.title("Inspection Engineer Assistant")

    # Cached across reruns via @st.cache_resource.
    model_handler = load_model()
    chat_interface = ChatInterface()

    # Replay the visible conversation (skip the system prompt at index 0).
    for message in st.session_state.messages[1:]:
        chat_interface.display_message(message["role"], message["content"])

    user_input = chat_interface.get_user_input()
    if not user_input:
        return

    st.session_state.messages.append({"role": "user", "content": user_input})
    chat_interface.display_message("user", user_input)

    # Flatten the whole history (system prompt included) into the plain-text
    # "role: content" format the fine-tuned model expects.
    conversation = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in st.session_state.messages
    )

    # BUG FIX: generate_response is a generator; the original stored the
    # generator OBJECT in the history and displayed its repr instead of text.
    # Drain it — each item is the cumulative text, so the last one wins.
    response = ""
    with st.spinner("Analyzing..."):
        for partial in model_handler.generate_response(conversation):
            response = partial

    st.session_state.messages.append({"role": "assistant", "content": response})
    chat_interface.display_message("assistant", response)


if __name__ == "__main__":
    main()