"""Streamlit chat front-end for a fine-tuned scope-classification language model."""

import json
import re
import time
from datetime import datetime

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# System prompt seeded as the first entry of every new chat session.
SYSTEM_PROMPT = (
    "You are an experienced inspection methods engineer. "
    "Your task is to classify the following scope: analyze the scope provided "
    "in the input and determine the class item as an output."
)

MODEL_NAME = "amiguel/classItem-FT-llama-3-1-8b-instruct"


class ModelHandler:
    """Pair a causal language model with its tokenizer and generate replies."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, conversation: str, max_new_tokens: int = 100) -> str:
        """Generate the model's continuation of ``conversation``.

        Args:
            conversation: The full prompt text to feed to the model.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            Only the newly generated text, with special tokens removed and
            surrounding whitespace stripped.
        """
        inputs = self.tokenizer(conversation, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # A decoder-only model returns prompt + continuation in one sequence;
        # drop the prompt tokens so the conversation is not echoed back.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


class ChatApp:
    """Streamlit application that chats with the scope-classification model."""

    def __init__(self):
        # set_page_config is issued before any other Streamlit call in this app.
        st.set_page_config(
            page_title="Inspection Engineer Chat",
            page_icon="🔍",
            layout="wide",
        )
        self.initialize_session_state()
        self.model_handler = self.load_model()

    def initialize_session_state(self) -> None:
        """Seed the session's message history with the system prompt if absent."""
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
            ]

    @staticmethod
    @st.cache_resource
    def load_model() -> ModelHandler:
        """Load the tokenizer and model and wrap them in a ``ModelHandler``.

        Decorated with ``st.cache_resource`` so the result is cached by Streamlit.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # 8-bit loading is requested only when a CUDA device is available.
        # NOTE(review): recent transformers releases prefer passing a
        # quantization config instead of ``load_in_8bit`` — confirm against the
        # pinned transformers version.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            load_in_8bit=device == "cuda",
        )
        return ModelHandler(model, tokenizer)

    def display_message(self, role: str, content: str) -> None:
        """Render one chat message under the given role."""
        with st.chat_message(role):
            st.write(content)

    def get_user_input(self):
        """Return the value of the chat input widget (falsy when nothing was sent)."""
        return st.chat_input("Type your message here...")

    def stream_response(self, response: str, delay: float = 0.01) -> str:
        """Reveal ``response`` word by word and return the text that was shown.

        Args:
            response: The complete text to display.
            delay: Seconds to pause between words (0.01-0.05 reads well).

        Returns:
            The displayed text: ``response`` with surrounding whitespace
            stripped and its internal whitespace (including newlines) intact.
        """
        placeholder = st.empty()
        text = response.strip()
        shown = ""
        # Each chunk is a word plus the whitespace that follows it, so line
        # breaks and spacing in the model output survive the streaming effect.
        for chunk in re.findall(r"\S+\s*", text):
            shown += chunk
            placeholder.markdown(shown + "▌")
            time.sleep(delay)
        placeholder.markdown(shown)
        return shown

    def save_chat_history(self) -> str:
        """Write the session's messages to a timestamped JSON file.

        Returns:
            The name of the file that was written (relative to the working
            directory).
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.json"
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(st.session_state.messages, f, indent=2)
        return filename

    def run(self) -> None:
        """Render the page, handle one user turn, and expose the save button."""
        col1, col2 = st.columns([3, 1])

        with col1:
            st.title("Inspection Methods Engineer Assistant")

            # Skip index 0: the system prompt is never shown to the user.
            for message in st.session_state.messages[1:]:
                self.display_message(message["role"], message["content"])

            user_input = self.get_user_input()
            if user_input:
                self.display_message("user", user_input)
                st.session_state.messages.append(
                    {"role": "user", "content": user_input}
                )

                # The model is prompted with the whole history as plain
                # "role: content" paragraphs.
                # NOTE(review): this is not the tokenizer's chat template —
                # confirm it matches the format used during fine-tuning.
                conversation = "\n\n".join(
                    f"{msg['role']}: {msg['content']}"
                    for msg in st.session_state.messages
                ).strip()

                with st.spinner("Analyzing and classifying scope..."):
                    response = self.model_handler.generate_response(conversation)

                with st.chat_message("assistant"):
                    full_response = self.stream_response(response)
                st.session_state.messages.append(
                    {"role": "assistant", "content": full_response}
                )

        with col2:
            st.sidebar.title("Chat Options")
            if st.sidebar.button("Save Chat History"):
                filename = self.save_chat_history()
                st.sidebar.success(f"Chat history saved to {filename}")


if __name__ == "__main__":
    app = ChatApp()
    app.run()