import streamlit as st
import asyncio
import sys
from pathlib import Path
import base64
import pandas as pd
from typing import Literal, Tuple, Optional
from wiki import render_wiki_tab
from search_handlers import run_global_search, run_local_search, run_drift_search
import auth
import graphrag.api as api
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging import PrintProgressReporter
from graphrag.utils.storage import _create_storage, _load_table_from_storage
st.set_page_config(page_title="GraphRAG Chat Interface", page_icon="🔍", layout="wide")

# Avatars used for chat messages unless a session has already chosen others.
DEFAULT_USER_AVATAR = "👤"
DEFAULT_BOT_AVATAR = "🤖"

# Seed the per-session avatar choices only on first run; later reruns keep
# whatever value is already stored in session state.
if "user_avatar" not in st.session_state:
    st.session_state["user_avatar"] = DEFAULT_USER_AVATAR
if "bot_avatar" not in st.session_state:
    st.session_state["bot_avatar"] = DEFAULT_BOT_AVATAR

# Same emoji as the defaults above; kept as separate module-level names.
USER_AVATAR = DEFAULT_USER_AVATAR  # default user emoji
BOT_AVATAR = DEFAULT_BOT_AVATAR  # default bot emoji
class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that shows success messages in a Streamlit placeholder.

    Only ``success`` is overridden; every other reporting method is inherited
    from ``PrintProgressReporter`` unchanged.
    """

    def __init__(self, placeholder):
        # Base reporter gets an empty string (presumably its message prefix --
        # confirm against graphrag's PrintProgressReporter signature).
        super().__init__("")
        # Streamlit element (e.g. from st.empty()) that receives success alerts.
        self.placeholder = placeholder

    def success(self, message: str) -> None:
        """Display *message* as a success alert inside the placeholder."""
        target = self.placeholder
        target.success(message)
def render_chat_tab() -> None:
    """Render the Chat tab: message history, latest search context and query box.

    When the user submits a query it is appended to ``st.session_state.messages``
    and sent to the backend selected by ``st.session_state.search_type``:
    "global" -> ``run_global_search``, "drift" -> ``run_drift_search``, and
    anything else -> ``run_local_search``. The answer, or an error message if
    the search raised, is appended as an assistant message. The script is then
    rerun so the updated history is drawn.
    """
    format_message_history()

    # The context of the most recent successful search is kept in session state
    # and drawn here. It must survive st.rerun(), which is called at the end of
    # the query branch below; anything drawn in that branch would be discarded.
    last_context = st.session_state.get("last_search_context")
    if last_context is not None:
        with st.expander("View Search Context"):
            st.json(last_context)

    prompt = st.chat_input("Enter your query...")
    if not prompt:
        return

    # Record the user's message with a short display timestamp.
    st.session_state.messages.append(
        {
            "role": "user",
            "content": prompt,
            "timestamp": pd.Timestamp.now().strftime("%H:%M"),
        }
    )

    # All three backends take identical arguments; unknown types use local search.
    search_fn = {
        "global": run_global_search,
        "drift": run_drift_search,
    }.get(st.session_state.search_type, run_local_search)

    with st.spinner("Processing your query..."):
        response_placeholder = st.empty()
        try:
            response, context = search_fn(
                config_filepath=st.session_state.config_filepath,
                data_dir=st.session_state.data_dir,
                root_dir=st.session_state.root_dir,
                community_level=st.session_state.community_level,
                response_type=st.session_state.response_type,
                streaming=st.session_state.streaming,
                query=prompt,
                progress_placeholder=response_placeholder,
            )
        except Exception as e:  # UI boundary: report any backend failure in the chat.
            st.session_state.messages.append(
                {
                    "role": "assistant",
                    "content": f"Error processing query: {e}",
                    "timestamp": pd.Timestamp.now().strftime("%H:%M"),
                }
            )
            # Don't show a previous query's context next to this error.
            st.session_state["last_search_context"] = None
        else:
            # Clear streamed/progress output; the final answer lives in history.
            response_placeholder.empty()
            st.session_state.messages.append(
                {
                    "role": "assistant",
                    "content": response,
                    "timestamp": pd.Timestamp.now().strftime("%H:%M"),
                }
            )
            st.session_state["last_search_context"] = context

    st.rerun()
def display_message(msg: str, is_user: bool = False, timestamp: str = "") -> None:
    """Display one chat message with the session's avatar and optional timestamp.

    Args:
        msg: Message body. It is rendered as Markdown with HTML disabled, so
            text typed by the user or produced by the model cannot inject markup.
        is_user: True for a user turn, False for an assistant turn.
        timestamp: Short display time such as "14:05"; nothing is shown when empty.
    """
    role = "user" if is_user else "assistant"
    avatar = st.session_state.user_avatar if is_user else st.session_state.bot_avatar
    # st.chat_message groups avatar and content into one element without
    # hand-built HTML, so unsafe_allow_html is not needed for message text.
    with st.chat_message(role, avatar=avatar):
        st.markdown(msg)
        if timestamp:
            st.caption(timestamp)
def format_message_history() -> None:
    """Display every message in ``st.session_state.messages``, oldest first.

    Each entry must have "role" and "content" keys; "timestamp" is optional.
    Entries whose role is "user" are shown as user turns, all others as
    assistant turns.
    """
    # No raw-HTML wrapper is emitted around the history. Each st.markdown call
    # becomes a separate element, so an opening <div> in one call cannot
    # enclose elements drawn by later calls.
    for message in st.session_state.messages:
        display_message(
            msg=message["content"],
            is_user=(message["role"] == "user"),
            timestamp=message.get("timestamp", ""),
        )
@st.cache_resource
def load_css(path: str = "styles.css") -> str:
    """Return the text of the app's stylesheet (cached across reruns).

    A relative *path* is first looked up next to this script. If no file is
    found there, it is resolved against the current working directory. This
    lets the app start from any directory, e.g. ``streamlit run dir/app.py``.

    Args:
        path: Stylesheet location; defaults to "styles.css".

    Raises:
        FileNotFoundError: If the stylesheet is found in neither location.
    """
    css_path = Path(path)
    if not css_path.is_absolute():
        beside_script = Path(__file__).resolve().parent / css_path
        if beside_script.is_file():
            css_path = beside_script
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # fail on non-ASCII characters in the stylesheet.
    return css_path.read_text(encoding="utf-8")
def initialize_session_state():
    """Seed every session-state key the app reads, keeping any existing values.

    Safe to call on every rerun: a key is only written when it is missing.
    """
    # Built per call so "messages" always starts as a fresh list.
    defaults = {
        "messages": [],
        "response_placeholder": None,
        "config_filepath": None,
        "data_dir": None,
        "root_dir": ".",
        "community_level": 2,
        "response_type": "concise",
        "search_type": "global",
        "streaming": True,
        "authenticated": False,
    }
    for key, initial in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = initial
def _render_sidebar() -> None:
    """Draw the sidebar: logos, search configuration widgets and logout button."""
    with st.sidebar:
        # NOTE(review): the logo markup is empty in this copy of the file;
        # restore the intended <img> HTML or drop these placeholders.
        col1, col2 = st.columns(2)
        with col1:
            st.markdown(
                '',
                unsafe_allow_html=True,
            )
        with col2:
            st.markdown(
                '',
                unsafe_allow_html=True,
            )

        st.header("Configuration")
        st.session_state.community_level = st.number_input(
            "Community Level",
            min_value=0,
            max_value=10,
            value=st.session_state.community_level,
            help="Controls the granularity of the search...",
        )

        # Response type is only offered for global and local search.
        if st.session_state.search_type != "drift":
            st.session_state.response_type = st.selectbox(
                "Response Type",
                options=["concise", "detailed"],
                index=0 if st.session_state.response_type == "concise" else 1,
                help="Style of response generation",
            )

        search_help = (
            "Search Types:\n"
            '- Local Search: "Focuses on finding specific information by searching through direct connections in the knowledge graph. Best for precise, fact-based queries."\n'
            '- Global Search: "Analyzes the entire document collection at a high level using community summaries. Best for understanding broad themes and general policies."\n'
            '- DRIFT Search: "Combines local and global search capabilities, dynamically exploring connections while gathering detailed information. Best for complex queries requiring both specific details and broader context."\n'
        )
        st.session_state.search_type = st.selectbox(
            "Search Type",
            options=["global", "local", "drift"],
            # Any value other than "global"/"local" selects "drift" (index 2).
            index={"global": 0, "local": 1}.get(st.session_state.search_type, 2),
            help=search_help,
        )

        # Streaming is offered only for search types that support it.
        if st.session_state.search_type != "drift":
            st.session_state.streaming = st.checkbox(
                "Enable Streaming",
                value=st.session_state.streaming,
                help="Stream response tokens as they're generated",
            )
        else:
            st.session_state.streaming = False
            st.info("Streaming is not available for DRIFT search")

        if st.button("Logout"):
            st.session_state.clear()  # drop all session data, including auth
            initialize_session_state()
            # st.query_params is a proxy object. Assigning to the attribute
            # would rebind the shared streamlit module attribute to a plain
            # dict (for every session) instead of updating this page's URL.
            st.query_params.clear()
            st.query_params["restart"] = "true"
            st.rerun()


def main():
    """Streamlit entry point: authenticate, then render the sidebar and tabs."""
    initialize_session_state()

    # Gate the app behind auth.check_credentials(). On success, rerun so the
    # page redraws in the authenticated state; otherwise stop this run here.
    if not st.session_state.authenticated:
        if not auth.check_credentials():
            st.stop()
        st.session_state.authenticated = True
        st.rerun()

    st.title("PWC Home Assignment #1, Graphrag")
    # Inject the stylesheet; the CSS text must be wrapped in a <style> tag to apply.
    css = load_css()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)

    _render_sidebar()

    tab1, tab2 = st.tabs(["Assignment Documentation", "Chat"])
    with tab1:
        render_wiki_tab()
    with tab2:
        render_chat_tab()

    # NOTE(review): this sidebar markup is empty in this copy of the file;
    # restore the intended footer HTML or remove the call.
    st.sidebar.markdown(
        """
""",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()