amiguel committed on
Commit
ab222f8
·
verified ·
1 Parent(s): 23630fe

Upload clean-chat-interface.py

Browse files
Files changed (1) hide show
  1. clean-chat-interface.py +103 -0
clean-chat-interface.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+ import time
5
+ import json
6
+ from datetime import datetime
7
+
8
class ChatApp:
    """Streamlit chat UI that classifies inspection scopes with a fine-tuned causal LM.

    Renders a chat interface, forwards the running conversation to a
    ModelHandler for generation, and offers a sidebar button to dump the
    chat history to a timestamped JSON file.
    """

    def __init__(self):
        # set_page_config must be the first Streamlit call in the script run.
        st.set_page_config(page_title="Inspection Methods Engineer Assistant", page_icon="🔍", layout="wide")
        self.initialize_session_state()
        self.model_handler = self.load_model()

    def initialize_session_state(self):
        """Seed the session with the system prompt exactly once per browser session."""
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "system", "content": "You are an experienced inspection methods engineer. Your task is to classify the following scope: analyze the scope provided in the input and determine the class item as an output."}
            ]

    @staticmethod
    @st.cache_resource
    def load_model():
        """Load tokenizer and model once (cached across reruns by st.cache_resource).

        Returns:
            ModelHandler wrapping the loaded model and tokenizer.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        model_name = "amiguel/classItem-FT-llama-3-1-8b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            # 8-bit quantization only when a GPU is available.
            load_in_8bit=device == "cuda"
        )
        return ModelHandler(model, tokenizer)

    def display_message(self, role, content):
        """Render one chat bubble for *role* containing markdown *content*."""
        with st.chat_message(role):
            st.markdown(content)

    def get_user_input(self):
        """Return the text typed into the chat box, or None if nothing was submitted."""
        return st.chat_input("Type your message here...")

    def stream_response(self, response):
        """Simulate token streaming by revealing *response* word by word.

        Returns:
            The full text that was rendered (with a trailing space per word).
        """
        placeholder = st.empty()
        full_response = ""
        for word in response.split():
            full_response += word + " "
            placeholder.markdown(full_response + "▌")  # cursor glyph while "typing"
            time.sleep(0.01)
        placeholder.markdown(full_response)
        return full_response

    def save_chat_history(self):
        """Write the full message list (including the system prompt) to a JSON file.

        Returns:
            The generated filename, e.g. 'chat_history_20240101_120000.json'.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.json"
        with open(filename, "w") as f:
            json.dump(st.session_state.messages, f, indent=2)
        return filename

    def run(self):
        """Main render loop: replay history, handle new input, offer history export."""
        st.title("Inspection Methods Engineer Assistant")

        # Replay prior turns; the system prompt stays hidden from the user.
        for message in st.session_state.messages:
            if message["role"] != "system":
                self.display_message(message["role"], message["content"])

        user_input = self.get_user_input()
        if user_input:
            self.display_message("user", user_input)
            st.session_state.messages.append({"role": "user", "content": user_input})

            # Flatten the whole conversation (system prompt included) into one prompt.
            conversation = "\n\n".join([msg["content"] for msg in st.session_state.messages])

            with st.spinner("Analyzing and classifying scope..."):
                response = self.model_handler.generate_response(conversation.strip())

            clean_response = self.clean_response(response)
            with st.chat_message("assistant"):
                full_response = self.stream_response(clean_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})

        st.sidebar.title("Chat Options")
        if st.sidebar.button("Save Chat History"):
            filename = self.save_chat_history()
            # BUG FIX: the original f-string had no placeholder and printed the
            # literal text "(unknown)" instead of the saved file's name.
            st.sidebar.success(f"Chat history saved to {filename}")

    def clean_response(self, response):
        """Strip 'role:' style prefixes from each line of the model output.

        NOTE: for any line containing ':' this keeps only the text after the
        FIRST colon, so colons inside normal sentences are also truncated.
        """
        lines = response.split('\n')
        clean_lines = [line.split(':', 1)[-1].strip() if ':' in line else line for line in lines]
        return '\n'.join(clean_lines)
90
+
91
class ModelHandler:
    """Thin wrapper pairing a causal LM with its tokenizer for one-shot generation."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, conversation):
        """Tokenize *conversation*, generate up to 100 new tokens, and decode them.

        Returns:
            The decoded first output sequence, special tokens stripped.
        """
        encoded = self.tokenizer(conversation, return_tensors="pt")
        encoded = encoded.to(self.model.device)
        generated = self.model.generate(**encoded, max_new_tokens=100)
        first_sequence = generated[0]
        return self.tokenizer.decode(first_sequence, skip_special_tokens=True)
100
+
101
if __name__ == "__main__":
    # Build the app and enter the Streamlit render loop.
    ChatApp().run()