amiguel commited on
Commit
ada1bc5
·
verified ·
1 Parent(s): 15c6064

Upload restructured-chat-app.py

Browse files
Files changed (1) hide show
  1. restructured-chat-app.py +112 -0
restructured-chat-app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File: app.py
2
+ import streamlit as st
3
+ from chat_app.chat_app import ChatApp
4
+
5
+ if __name__ == "__main__":
6
+ app = ChatApp()
7
+ app.run()
8
+
9
+ # File: chat_app/chat_app.py
10
+ import streamlit as st
11
+ from datetime import datetime
12
+ import json
13
+ import time
14
+ from .model_handler import ModelHandler
15
+
16
+ class ChatApp:
17
+ def __init__(self):
18
+ st.set_page_config(page_title="Inspection Engineer Chat", page_icon="🔍", layout="wide")
19
+ self.initialize_session_state()
20
+ self.model_handler = self.load_model()
21
+
22
+ def initialize_session_state(self):
23
+ if "messages" not in st.session_state:
24
+ st.session_state.messages = [
25
+ {"role": "system", "content": "You are an experienced inspection methods engineer. Your task is to classify the following scope: analyze the scope provided in the input and determine the class item as an output."}
26
+ ]
27
+
28
+ @staticmethod
29
+ @st.cache_resource
30
+ def load_model():
31
+ return ModelHandler.load_model()
32
+
33
+ def display_message(self, role, content):
34
+ with st.chat_message(role):
35
+ st.write(content)
36
+
37
+ def get_user_input(self):
38
+ return st.chat_input("Type your message here...")
39
+
40
+ def stream_response(self, response):
41
+ placeholder = st.empty()
42
+ full_response = ""
43
+ for word in response.split():
44
+ full_response += word + " "
45
+ placeholder.markdown(full_response + "▌")
46
+ time.sleep(0.01)
47
+ placeholder.markdown(full_response)
48
+ return full_response
49
+
50
+ def save_chat_history(self):
51
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
52
+ filename = f"chat_history_{timestamp}.json"
53
+ with open(filename, "w") as f:
54
+ json.dump(st.session_state.messages, f, indent=2)
55
+ return filename
56
+
57
+ def run(self):
58
+ col1, col2 = st.columns([3, 1])
59
+ with col1:
60
+ st.title("Inspection Methods Engineer Assistant")
61
+ for message in st.session_state.messages[1:]:
62
+ self.display_message(message["role"], message["content"])
63
+
64
+ user_input = self.get_user_input()
65
+ if user_input:
66
+ self.display_message("user", user_input)
67
+ st.session_state.messages.append({"role": "user", "content": user_input})
68
+
69
+ conversation = ""
70
+ for msg in st.session_state.messages:
71
+ conversation += f"{msg['role']}: {msg['content']}\n\n"
72
+
73
+ with st.spinner("Analyzing and classifying scope..."):
74
+ response = self.model_handler.generate_response(conversation.strip())
75
+
76
+ with st.chat_message("assistant"):
77
+ full_response = self.stream_response(response)
78
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
79
+
80
+ with col2:
81
+ st.sidebar.title("Chat Options")
82
+ if st.sidebar.button("Save Chat History"):
83
+ filename = self.save_chat_history()
84
+ st.sidebar.success(f"Chat history saved to {filename}")
85
+
86
+ # File: chat_app/model_handler.py
87
+ import streamlit as st
88
+ import torch
89
+ from transformers import AutoTokenizer, AutoModelForCausalLM
90
+
91
+ class ModelHandler:
92
+ def __init__(self, model, tokenizer):
93
+ self.model = model
94
+ self.tokenizer = tokenizer
95
+
96
+ @staticmethod
97
+ def load_model():
98
+ device = "cuda" if torch.cuda.is_available() else "cpu"
99
+ st.info(f"Using device: {device}")
100
+ model_name = "amiguel/classItem-FT-llama-3-1-8b-instruct"
101
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
102
+ model = AutoModelForCausalLM.from_pretrained(
103
+ model_name,
104
+ device_map="auto",
105
+ load_in_8bit=device == "cuda"
106
+ )
107
+ return ModelHandler(model, tokenizer)
108
+
109
+ def generate_response(self, conversation):
110
+ inputs = self.tokenizer(conversation, return_tensors="pt").to(self.model.device)
111
+ outputs = self.model.generate(**inputs, max_new_tokens=100)
112
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)