# KiddoTheBERTo / app.bak
# Page header residue from the hosting site: avatar caption "JustKiddo's picture",
# commit message "Create app.bak", commit 9d2a89b (verified).
import streamlit as st
import requests
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import plotly.graph_objects as go
from datetime import datetime
import json
from collections import deque
from datasets import load_dataset
class BERTopicChatbot:
    """Retrieval chatbot built on a BERTopic model and sentence embeddings.

    A Hugging Face dataset is loaded, one text column is embedded with a
    sentence-transformer, and a BERTopic model is fitted on it.  A query is
    answered with the stored document whose embedding has the highest cosine
    similarity to the query embedding.  Per-query metrics are recorded in
    ``metrics_history``.
    """

    # A query whose best similarity is below this value gets the fallback
    # answer instead of a retrieved document.
    SIMILARITY_THRESHOLD = 0.5
    FALLBACK_RESPONSE = "Xin lỗi, tôi không có đủ thông tin để trả lời câu hỏi này một cách chính xác."
    # Number of most recent queries kept in each rolling metric series.
    METRICS_WINDOW = 100

    def __init__(self, dataset_name, text_column, split="train", max_samples=10000):
        """Load the dataset, fit the topic model and prepare metric storage.

        Args:
            dataset_name: Name of the dataset on Hugging Face
                (e.g. 'vietnam/legal').
            text_column: Name of the column containing the text data.
            split: Which split of the dataset to use
                ('train', 'test', 'validation').
            max_samples: Maximum number of rows to use (to manage memory).
                Larger datasets are shuffled with a fixed seed and truncated.

        Raises:
            ValueError: If ``text_column`` is missing, or if it holds no
                non-blank text after cleaning.
            Exception: Any loading/fitting error is reported through
                ``st.error`` and then re-raised unchanged.
        """
        # Initialize BERT sentence transformer
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        try:
            # Load dataset from Hugging Face and sample if necessary
            dataset = load_dataset(dataset_name, split=split)
            if len(dataset) > max_samples:
                dataset = dataset.shuffle(seed=42).select(range(max_samples))
            frame = dataset.to_pandas()

            # Ensure text column exists
            if text_column not in frame.columns:
                raise ValueError(f"Column '{text_column}' not found in dataset. Available columns: {frame.columns}")

            # Drop null and blank rows: they carry no text to embed, and a
            # None entry would later break ``response.split()``.  The frame
            # is filtered too so that row i still matches documents[i].
            text = frame[text_column]
            usable = text.notna() & (text.astype(str).str.strip() != "")
            self.df = frame.loc[usable].reset_index(drop=True)
            self.documents = self.df[text_column].astype(str).tolist()
            if not self.documents:
                raise ValueError(f"Column '{text_column}' contains no usable text.")

            # Encode the corpus once and hand the embeddings to BERTopic,
            # instead of letting fit_transform encode every document again.
            self.doc_embeddings = self.sentence_model.encode(self.documents)
            self.topic_model = BERTopic(embedding_model=self.sentence_model)
            self.topics, self.probs = self.topic_model.fit_transform(
                self.documents, embeddings=self.doc_embeddings
            )

            # Initialize metrics storage (rolling windows + per-topic counter)
            self.metrics_history = {
                'similarities': deque(maxlen=self.METRICS_WINDOW),
                'response_times': deque(maxlen=self.METRICS_WINDOW),
                'token_counts': deque(maxlen=self.METRICS_WINDOW),
                'topics_accessed': {}
            }

            # Store dataset info.  Topic id -1 is BERTopic's outlier bucket,
            # so it is not counted as a discovered topic.
            self.dataset_info = {
                'name': dataset_name,
                'split': split,
                'total_documents': len(self.documents),
                'topics_found': len(set(self.topics) - {-1})
            }
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            raise

    @staticmethod
    def _trend_figure(values, trace_name, title, yaxis_title):
        """Build one lines+markers chart of ``values`` against query number."""
        figure = go.Figure()
        figure.add_trace(go.Scatter(
            y=list(values),
            mode='lines+markers',
            name=trace_name
        ))
        figure.update_layout(
            title=title,
            yaxis_title=yaxis_title,
            xaxis_title='Query Number'
        )
        return figure

    def get_metrics_visualizations(self):
        """Generate visualizations for chatbot metrics.

        Returns:
            A 4-tuple of Plotly figures, in this order: similarity trend,
            response-time trend, token-count trend, topics-accessed pie chart.
        """
        history = self.metrics_history
        fig_similarity = self._trend_figure(
            history['similarities'], 'Similarity Score',
            'Response Similarity Trend', 'Similarity Score'
        )
        fig_response_time = self._trend_figure(
            history['response_times'], 'Response Time',
            'Response Time Trend', 'Time (seconds)'
        )
        fig_tokens = self._trend_figure(
            history['token_counts'], 'Token Count',
            'Token Usage Trend', 'Number of Tokens'
        )

        # Topics accessed pie chart
        labels = list(history['topics_accessed'].keys())
        values = list(history['topics_accessed'].values())
        fig_topics = go.Figure(data=[go.Pie(labels=labels, values=values)])
        fig_topics.update_layout(title='Topics Accessed Distribution')

        # Make all figures responsive
        figures = (fig_similarity, fig_response_time, fig_tokens, fig_topics)
        for fig in figures:
            fig.update_layout(
                autosize=True,
                margin=dict(l=20, r=20, t=40, b=20),
                height=300
            )
        return figures

    def get_most_similar_document(self, query: str, top_k: int = 3):
        """Return the ``top_k`` stored documents most similar to ``query``.

        Args:
            query: Text to compare against the stored documents.
            top_k: Number of documents to return.  Values of zero or below
                yield an empty result.

        Returns:
            ``(documents, similarities)``: the matching documents, best
            first, and an array with their cosine similarities in the same
            order.
        """
        # Without this guard ``argsort()[-0:]`` would select *every*
        # document, and a negative top_k would select an arbitrary tail.
        if top_k <= 0:
            return [], np.array([], dtype=float)

        # Encode the query (2-D result: one row per input text)
        query_embedding = self.sentence_model.encode([query])
        # Calculate similarities against every stored document
        similarities = cosine_similarity(query_embedding, self.doc_embeddings)[0]
        # argsort is ascending: take the last top_k and reverse for best-first
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.documents[i] for i in top_indices], similarities[top_indices]

    def get_response(self, user_query: str):
        """Answer ``user_query`` and record metrics for the exchange.

        Returns:
            ``(response, metrics)``.  On success ``metrics`` has the keys
            'similarity', 'response_time', 'tokens' and 'topic'.  If anything
            raises, the response is an error message and ``metrics`` is
            ``{'error': <message>}``; no exception escapes this method.
        """
        try:
            start_time = datetime.now()

            # Get most similar documents
            similar_docs, similarities = self.get_most_similar_document(user_query)

            # Get topic for the query and track topic access
            query_topic, _ = self.topic_model.transform([user_query])
            topic_id = str(query_topic[0])
            accessed = self.metrics_history['topics_accessed']
            accessed[topic_id] = accessed.get(topic_id, 0) + 1

            # If similarity is too low, return a default response
            best_similarity = float(max(similarities))
            if best_similarity < self.SIMILARITY_THRESHOLD:
                response = self.FALLBACK_RESPONSE
            else:
                response = similar_docs[0]

            # Track metrics; "tokens" are whitespace-separated words
            elapsed = (datetime.now() - start_time).total_seconds()
            token_count = len(response.split())
            self.metrics_history['similarities'].append(best_similarity)
            self.metrics_history['response_times'].append(elapsed)
            self.metrics_history['token_counts'].append(token_count)

            metrics = {
                'similarity': best_similarity,
                'response_time': elapsed,
                'tokens': token_count,
                'topic': topic_id
            }
            return response, metrics
        except Exception as e:
            return f"Error processing query: {str(e)}", {'error': str(e)}

    def get_dataset_info(self):
        """Return information about the loaded dataset and metrics.

        Averages and the token total cover only the queries still held in
        the rolling metric windows.

        Returns:
            A dict with 'dataset_info' and 'metrics'.  If building it fails,
            a dict with 'error' set and both other keys set to ``None``.
        """
        try:
            history = self.metrics_history
            similarities = list(history['similarities'])
            response_times = list(history['response_times'])
            return {
                'dataset_info': self.dataset_info,
                'metrics': {
                    # np.mean of an empty list is NaN, hence the 0.0 default
                    'avg_similarity': float(np.mean(similarities)) if similarities else 0.0,
                    'avg_response_time': float(np.mean(response_times)) if response_times else 0.0,
                    'total_tokens': sum(history['token_counts']),
                    'topics_accessed': history['topics_accessed']
                }
            }
        except Exception as e:
            return {
                'error': str(e),
                'dataset_info': None,
                'metrics': None
            }
@st.cache_resource
def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000):
    """Build a ``BERTopicChatbot`` for the given dataset settings.

    The function is wrapped in ``st.cache_resource``.
    NOTE(review): per Streamlit's documentation a cached resource is reused
    for identical arguments and shared across sessions, which would make the
    chatbot's ``metrics_history`` shared as well - confirm this is intended.
    """
    chatbot = BERTopicChatbot(
        dataset_name,
        text_column,
        split=split,
        max_samples=max_samples,
    )
    return chatbot
def _render_sidebar():
    """Draw the dataset configuration sidebar and handle "Load Dataset"."""
    with st.sidebar:
        st.header("Dataset Configuration")
        dataset_name = st.text_input(
            "Hugging Face Dataset Name",
            value="Kanakmi/mental-disorders",
            help="Enter the name of a dataset from Hugging Face (e.g., 'Kanakmi/mental-disorders')"
        )
        text_column = st.text_input(
            "Text Column Name",
            value="text",
            help="Enter the name of the column containing the text data"
        )
        split = st.selectbox(
            "Dataset Split",
            options=["train", "test", "validation"],
            index=0
        )
        max_samples = st.number_input(
            "Maximum Samples",
            min_value=100,
            max_value=100000,
            value=10000,
            step=1000,
            help="Maximum number of samples to load from the dataset"
        )

        load_requested = st.button("Load Dataset")
        if not load_requested:
            return

        with st.spinner("Loading dataset and initializing model..."):
            try:
                st.session_state.chatbot = initialize_chatbot(
                    dataset_name, text_column, split, max_samples
                )
                st.success("Dataset loaded successfully!")
            except Exception as e:
                st.error(f"Error loading dataset: {str(e)}")


def _render_chat_tab():
    """Replay the stored conversation and process one new prompt, if any."""
    # Display existing messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Only show chat input if chatbot is initialized
    chatbot = st.session_state.chatbot
    if chatbot is None:
        st.info("Please load a dataset first to start chatting.")
        return

    prompt = st.chat_input("Hãy nói gì đó...")
    if not prompt:
        return

    # Add user message
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Get chatbot response
    response, metrics = chatbot.get_response(prompt)

    # Add assistant response together with its per-query metrics
    with st.chat_message("assistant"):
        st.markdown(response)
        with st.expander("Response Metrics"):
            st.json(metrics)
    st.session_state.messages.append({"role": "assistant", "content": response})


def _render_metrics_tab():
    """Show the metric charts and the aggregate statistics."""
    chatbot = st.session_state.chatbot
    if chatbot is None:
        st.info("Please load a dataset first to view metrics.")
        return

    try:
        # Get visualizations from the session's chatbot
        fig_similarity, fig_response_time, fig_tokens, fig_topics = (
            chatbot.get_metrics_visualizations()
        )
        left_column, right_column = st.columns(2)
        with left_column:
            st.plotly_chart(fig_similarity, use_container_width=True)
            st.plotly_chart(fig_tokens, use_container_width=True)
        with right_column:
            st.plotly_chart(fig_response_time, use_container_width=True)
            st.plotly_chart(fig_topics, use_container_width=True)

        # Display statistics
        st.subheader("Overall Statistics")
        history = chatbot.metrics_history
        if not history['similarities']:
            st.info("No chat history available yet. Start a conversation to see metrics.")
            return

        similarity_cell, time_cell, tokens_cell = st.columns(3)
        with similarity_cell:
            mean_similarity = np.mean(list(history['similarities']))
            st.metric("Avg Similarity", f"{mean_similarity:.3f}")
        with time_cell:
            mean_response_time = np.mean(list(history['response_times']))
            st.metric("Avg Response Time", f"{mean_response_time:.3f}s")
        with tokens_cell:
            st.metric("Total Tokens Used", sum(history['token_counts']))
    except Exception as e:
        st.error(f"Error displaying metrics: {str(e)}")


def main():
    """Render the Streamlit page: title, sidebar, chat tab and metrics tab."""
    st.title("🤖 Trợ Lý AI - BERTopic")
    st.caption("Trò chuyện với chúng mình nhé!")

    # Dataset selection sidebar
    _render_sidebar()

    # Initialize session state variables if they don't exist
    for key, initial in (('chatbot', None), ('messages', [])):
        if key not in st.session_state:
            st.session_state[key] = initial

    # Create tabs for chat and metrics
    chat_tab, metrics_tab = st.tabs(["Chat", "Metrics"])
    with chat_tab:
        _render_chat_tab()
    with metrics_tab:
        _render_metrics_tab()


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()