"""Streamlit chat front-end for a fine-tuned Llama 3.1 scope classifier."""

import json
import re
import time
from datetime import datetime

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Matches a leading role label such as "system:", "User :" or "assistant:" on a line.
_ROLE_PREFIX_RE = re.compile(r"^\s*(?:system|user|assistant)\s*:\s*", re.IGNORECASE)


class ChatApp:
    """Streamlit chat UI that forwards the conversation to a ``ModelHandler``."""

    def __init__(self):
        # set_page_config must be the first Streamlit call on every script run.
        st.set_page_config(
            page_title="Inspection Methods Engineer Assistant",
            page_icon="🔍",
            layout="wide",
        )
        self.initialize_session_state()
        self.model_handler = self.load_model()

    def initialize_session_state(self):
        """Seed ``st.session_state.messages`` with the system prompt on first run."""
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {
                    "role": "system",
                    "content": (
                        "You are an experienced inspection methods engineer. "
                        "Your task is to classify the following scope: analyze the scope "
                        "provided in the input and determine the class item as an output."
                    ),
                }
            ]

    @staticmethod
    @st.cache_resource
    def load_model():
        """Load the tokenizer and model once and wrap them in a ``ModelHandler``.

        ``st.cache_resource`` keeps the returned object alive across Streamlit
        reruns, so the weights are not reloaded on every interaction.

        Returns:
            ModelHandler: wrapper around the loaded model and tokenizer.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        model_name = "amiguel/classItem-FT-llama-3-1-8b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # 8-bit quantization is only requested on CUDA. The bare `load_in_8bit`
        # kwarg is deprecated in favour of an explicit quantization_config.
        quantization_config = (
            BitsAndBytesConfig(load_in_8bit=True) if device == "cuda" else None
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=quantization_config,
        )
        return ModelHandler(model, tokenizer)

    def display_message(self, role, content):
        """Render *content* as Markdown inside a chat bubble for *role*."""
        with st.chat_message(role):
            st.markdown(content)

    def get_user_input(self):
        """Return the text submitted in the chat input box, or None if nothing was sent."""
        return st.chat_input("Type your message here...")

    def stream_response(self, response):
        """Render *response* progressively with a typing cursor and return it.

        Whitespace runs (including newlines) are kept, so Markdown such as
        lists and paragraphs renders correctly. The returned string equals
        *response* exactly.
        """
        placeholder = st.empty()
        full_response = ""
        # The capturing group keeps the separators, so re-joining the pieces
        # reproduces the original text byte-for-byte.
        for piece in re.split(r"(\s+)", response):
            if not piece:
                continue
            full_response += piece
            if not piece.isspace():
                placeholder.markdown(full_response + "▌")
                time.sleep(0.01)
        placeholder.markdown(full_response)
        return full_response

    def save_chat_history(self):
        """Write the current message list to a timestamped JSON file.

        Returns:
            str: the name of the file written in the current working directory.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.json"
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(st.session_state.messages, f, indent=2)
        return filename

    def run(self):
        """Render the chat history, handle one user turn, and draw the sidebar."""
        st.title("Inspection Methods Engineer Assistant")
        for message in st.session_state.messages:
            if message["role"] != "system":
                self.display_message(message["role"], message["content"])

        user_input = self.get_user_input()
        if user_input:
            self.display_message("user", user_input)
            st.session_state.messages.append({"role": "user", "content": user_input})
            with st.spinner("Analyzing and classifying scope..."):
                # Pass the structured messages so the handler can apply the
                # model's chat template (system/user/assistant turns).
                response = self.model_handler.generate_response(
                    st.session_state.messages
                )
            clean_response = self.clean_response(response)
            with st.chat_message("assistant"):
                full_response = self.stream_response(clean_response)
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )

        # Drawn on every rerun. If it only appeared after a new message, the
        # rerun triggered by clicking the button would never reach it.
        st.sidebar.title("Chat Options")
        if st.sidebar.button("Save Chat History"):
            filename = self.save_chat_history()
            st.sidebar.success(f"Chat history saved to {filename}")

    def clean_response(self, response):
        """Strip a leading "system:", "user:" or "assistant:" label from each line.

        Only genuine role labels are removed (case-insensitive). Other colons,
        e.g. "Class: X", URLs or times, are left untouched. A line that had a
        label removed is also stripped of surrounding whitespace.
        """
        cleaned = []
        for line in response.split("\n"):
            match = _ROLE_PREFIX_RE.match(line)
            cleaned.append(line[match.end():].strip() if match else line)
        return "\n".join(cleaned)


class ModelHandler:
    """Pairs a causal language model with its tokenizer for text generation."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, conversation, max_new_tokens=100):
        """Generate a reply to *conversation* and return only the new text.

        Args:
            conversation: either a prompt string, used as-is, or a list of
                ``{"role": ..., "content": ...}`` dicts. A list is rendered with
                the tokenizer's chat template when one is available.
            max_new_tokens: maximum number of tokens to generate.

        Returns:
            str: the decoded completion, excluding the prompt tokens.
        """
        prompt, add_special_tokens = self._build_prompt(conversation)
        inputs = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(self.model.device)

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Llama tokenizers define no pad token; reuse EOS to avoid a warning.
            pad_token_id = self.tokenizer.eos_token_id

        outputs = self.model.generate(
            **inputs, max_new_tokens=max_new_tokens, pad_token_id=pad_token_id
        )
        # generate() returns prompt + completion; decode only the completion so
        # the conversation (and system prompt) is not echoed back to the user.
        prompt_length = inputs["input_ids"].shape[-1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

    def _build_prompt(self, conversation):
        """Return ``(prompt_text, add_special_tokens)`` for *conversation*."""
        if isinstance(conversation, str):
            return conversation, True
        if getattr(self.tokenizer, "chat_template", None):
            prompt = self.tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            # The rendered template already contains BOS; don't add it again.
            return prompt, False
        # Fallback for tokenizers without a chat template: the original plain join.
        return "\n\n".join(msg["content"] for msg in conversation).strip(), True


if __name__ == "__main__":
    app = ChatApp()
    app.run()