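"""Streamlit chat app that classifies inspection scopes with a fine-tuned
Llama 3.1 8B instruct model (amiguel/classItem-FT-llama-3-1-8b-instruct).

Run with: streamlit run app1.py
"""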
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
import json
from datetime import datetime

class ChatApp:
    def __init__(self):
        st.set_page_config(page_title="Inspection Engineer Chat", page_icon="🔍", layout="wide")
        self.initialize_session_state()
        self.model_handler = self.load_model()

    def initialize_session_state(self):
        if "messages" not in st.session_state:
            st.session_state.messages = [
                {"role": "system", "content": "You are an experienced inspection methods engineer. Your task is to classify the following scope: analyze the scope provided in the input and determine the class item as an output."}
            ]
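
    # st.cache_resource keeps a single model/tokenizer instance alive across
    # Streamlit reruns, so the weights are loaded only once per server process.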
    @staticmethod
    @st.cache_resource
    def load_model():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        st.info(f"Using device: {device}")
        model_name = "amiguel/classItem-FT-llama-3-1-8b-instruct"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            load_in_8bit=(device == "cuda"),  # 8-bit quantization (requires bitsandbytes) only on GPU
        )
        return ModelHandler(model, tokenizer)

    def display_message(self, role, content):
        with st.chat_message(role):
            st.write(content)

    def get_user_input(self):
        return st.chat_input("Type your message here...")

    def stream_response(self, response):
        # The model returns the full text at once; reveal it word by word
        # to simulate streaming.
        placeholder = st.empty()
        full_response = ""
        for word in response.split():
            full_response += word + " "
            placeholder.markdown(full_response + "▌")
            time.sleep(0.01)  # adjust between 0.01 and 0.05 for the desired speed
        placeholder.markdown(full_response)
        return full_response

    def save_chat_history(self):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"chat_history_{timestamp}.json"
        with open(filename, "w") as f:
            json.dump(st.session_state.messages, f, indent=2)
        return filename

    def run(self):
        col1, col2 = st.columns([3, 1])
        with col1:
            st.title("Inspection Methods Engineer Assistant")
            # Render the visible history, skipping the system prompt.
            for message in st.session_state.messages[1:]:
                self.display_message(message["role"], message["content"])
            user_input = self.get_user_input()
            if user_input:
                self.display_message("user", user_input)
                st.session_state.messages.append({"role": "user", "content": user_input})
                # Flatten the conversation into a plain "role: content" prompt
                # (tokenizer.apply_chat_template would be the usual alternative
                # for an instruct model).
                conversation = ""
                for msg in st.session_state.messages:
                    conversation += f"{msg['role']}: {msg['content']}\n\n"
                with st.spinner("Analyzing and classifying scope..."):
                    response = self.model_handler.generate_response(conversation.strip())
                with st.chat_message("assistant"):
                    full_response = self.stream_response(response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})
        with col2:
            # Note: st.sidebar widgets render in the sidebar, not inside this column.
            st.sidebar.title("Chat Options")
            if st.sidebar.button("Save Chat History"):
                filename = self.save_chat_history()
                st.sidebar.success(f"Chat history saved to {filename}")

class ModelHandler:
    """Thin wrapper pairing the model with its tokenizer for generation."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, conversation):
        inputs = self.tokenizer(conversation, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=100)
        # outputs[0] contains the prompt followed by the new tokens; decode
        # only the newly generated part so the prompt is not echoed back.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

if __name__ == "__main__":
    app = ChatApp()
    app.run()