"""Streamlit chat front-end for querying a GraphRAG index.

After authenticating through the local ``auth`` module, the app shows a
sidebar with search configuration and two tabs: the assignment documentation
(``wiki.render_wiki_tab``) and a chat interface that dispatches each query to
the global, local, or DRIFT handler from ``search_handlers``.
"""

import streamlit as st
import asyncio
import sys
from pathlib import Path
import base64
import pandas as pd
from typing import Literal, Tuple, Optional
from wiki import render_wiki_tab
from search_handlers import run_global_search, run_local_search, run_drift_search
import auth
import graphrag.api as api
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging import PrintProgressReporter
from graphrag.utils.storage import _create_storage, _load_table_from_storage

# Must be the first Streamlit command executed in the script.
st.set_page_config(page_title="GraphRAG Chat Interface", page_icon="🔍", layout="wide")

# Default avatars (emoji) used until the session chooses its own.
DEFAULT_USER_AVATAR = "👤"
DEFAULT_BOT_AVATAR = "🤖"

# Kept as aliases of the defaults above for backward compatibility.
USER_AVATAR = DEFAULT_USER_AVATAR
BOT_AVATAR = DEFAULT_BOT_AVATAR

# Search type -> handler. Any other value falls back to run_local_search.
_SEARCH_HANDLERS = {
    "global": run_global_search,
    "drift": run_drift_search,
}

# Choices offered in the sidebar selectboxes (order defines widget index).
_SEARCH_TYPES = ("global", "local", "drift")
_RESPONSE_TYPES = ("concise", "detailed")

# Tooltip for the Search Type selectbox (rendered as markdown).
_SEARCH_TYPE_HELP = """Search Types:
- Local Search: "Focuses on finding specific information by searching through direct connections in the knowledge graph. Best for precise, fact-based queries."
- Global Search: "Analyzes the entire document collection at a high level using community summaries. Best for understanding broad themes and general policies."
- DRIFT Search: "Combines local and global search capabilities, dynamically exploring connections while gathering detailed information. Best for complex queries requiring both specific details and broader context."
"""


class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that shows success messages in a Streamlit placeholder."""

    def __init__(self, placeholder):
        super().__init__("")
        self.placeholder = placeholder

    def success(self, message: str):
        self.placeholder.success(message)


def _now_hhmm() -> str:
    """Return the current local time formatted as ``HH:MM`` for message stamps."""
    return pd.Timestamp.now().strftime("%H:%M")


def render_chat_tab():
    """Render the Chat tab: message history, chat input, and query handling.

    On submit, the prompt is stored and shown immediately, then dispatched to
    the search handler selected in the sidebar. The assistant reply (or an
    error message) is appended to ``st.session_state.messages`` and the
    script is rerun so the history re-renders with the new messages.
    """
    format_message_history()

    prompt = st.chat_input("Enter your query...")
    if not prompt:
        return

    timestamp = _now_hhmm()
    st.session_state.messages.append(
        {"role": "user", "content": prompt, "timestamp": timestamp}
    )
    # The history above was drawn before this prompt existed; show it now so
    # streamed output below has its question next to it.
    display_message(prompt, is_user=True, timestamp=timestamp)

    search = _SEARCH_HANDLERS.get(st.session_state.search_type, run_local_search)

    with st.spinner("Processing your query..."):
        # Handlers write streamed/progress output into this placeholder.
        response_placeholder = st.empty()
        try:
            response, context = search(
                config_filepath=st.session_state.config_filepath,
                data_dir=st.session_state.data_dir,
                root_dir=st.session_state.root_dir,
                community_level=st.session_state.community_level,
                response_type=st.session_state.response_type,
                streaming=st.session_state.streaming,
                query=prompt,
                progress_placeholder=response_placeholder,
            )
        except Exception as e:  # UI boundary: report any failure in the chat.
            reply = {
                "role": "assistant",
                "content": f"Error processing query: {e}",
                "timestamp": _now_hhmm(),
            }
        else:
            # Clear the streamed text; the final answer is re-rendered from
            # history after the rerun.
            response_placeholder.empty()
            # Store the context with the message: st.rerun() below discards
            # anything drawn only in this run, so an expander rendered here
            # would vanish immediately.
            reply = {
                "role": "assistant",
                "content": response,
                "timestamp": _now_hhmm(),
                "context": context,
            }
        st.session_state.messages.append(reply)

    st.rerun()


def display_message(
    msg: str, is_user: bool = False, timestamp: str = "", context=None
) -> None:
    """Display one chat message with its avatar, timestamp and optional context.

    Args:
        msg: Message text; rendered as markdown with raw HTML disabled, so
            user- or model-supplied HTML is shown as text, not injected.
        is_user: True for the user's messages, False for the assistant's.
        timestamp: Optional ``HH:MM`` stamp shown beneath the message.
        context: Optional search context, shown in a "View Search Context"
            expander when not None.
    """
    role = "user" if is_user else "assistant"
    avatar = st.session_state.user_avatar if is_user else st.session_state.bot_avatar
    with st.chat_message(role, avatar=avatar):
        st.markdown(str(msg))
        if timestamp:
            st.caption(timestamp)
        if context is not None:
            with st.expander("View Search Context"):
                st.json(context)


def format_message_history() -> None:
    """Display every message stored in ``st.session_state.messages`` in order."""
    for message in st.session_state.messages:
        display_message(
            msg=message["content"],
            is_user=message["role"] == "user",
            timestamp=message.get("timestamp", ""),
            context=message.get("context"),
        )


@st.cache_resource
def load_css():
    """Read and return ``styles.css`` (relative to the working directory), cached."""
    with open("styles.css", "r", encoding="utf-8") as f:
        return f.read()


def initialize_session_state():
    """Populate any missing session-state keys with their defaults.

    Existing values are left untouched, so this is safe to call on every run
    and again right after ``st.session_state.clear()``.
    """
    # Built per call so each session gets its own fresh ``messages`` list.
    defaults = {
        "messages": [],
        "response_placeholder": None,
        "config_filepath": None,
        "data_dir": None,
        "root_dir": ".",
        "community_level": 2,
        "response_type": "concise",
        "search_type": "global",
        "streaming": True,
        "authenticated": False,
        "user_avatar": DEFAULT_USER_AVATAR,
        "bot_avatar": DEFAULT_BOT_AVATAR,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _render_sidebar():
    """Draw the sidebar: logos, search configuration widgets and logout button."""
    with st.sidebar:
        # Display logos side by side at the top of the sidebar.
        # NOTE(review): both literals contain only a newline, so no logo is
        # rendered; restore the intended image markup.
        col1, col2 = st.columns(2)
        with col1:
            st.markdown("\n", unsafe_allow_html=True)
        with col2:
            st.markdown("\n", unsafe_allow_html=True)

        st.header("Configuration")
        st.session_state.community_level = st.number_input(
            "Community Level",
            min_value=0,
            max_value=10,
            value=st.session_state.community_level,
            help="Controls the granularity of the search...",
        )

        # Choose the search type first so the widgets below react to the
        # current selection rather than the one from the previous run.
        current = st.session_state.search_type
        st.session_state.search_type = st.selectbox(
            "Search Type",
            options=list(_SEARCH_TYPES),
            index=_SEARCH_TYPES.index(current) if current in _SEARCH_TYPES else 2,
            help=_SEARCH_TYPE_HELP,
        )

        # Response type and streaming are only offered for global/local search.
        if st.session_state.search_type != "drift":
            st.session_state.response_type = st.selectbox(
                "Response Type",
                options=list(_RESPONSE_TYPES),
                index=0 if st.session_state.response_type == "concise" else 1,
                help="Style of response generation",
            )
            st.session_state.streaming = st.checkbox(
                "Enable Streaming",
                value=st.session_state.streaming,
                help="Stream response tokens as they're generated",
            )
        else:
            st.session_state.streaming = False
            st.info("Streaming is not available for DRIFT search")

        if st.button("Logout"):
            st.session_state.clear()
            initialize_session_state()
            # Update the shared query-params object in place. Assigning to
            # st.query_params would rebind the streamlit module attribute to
            # a plain dict for every session in the process.
            st.query_params.from_dict({"restart": "true"})
            st.rerun()


def main():
    """Entry point: authenticate, then render the sidebar and both tabs."""
    initialize_session_state()

    if not st.session_state.authenticated:
        if auth.check_credentials():
            st.session_state.authenticated = True
            st.rerun()  # Rerun to reflect the authentication state.
        st.stop()  # Stop further execution if authentication fails.

    st.title("PWC Home Assignment #1, Graphrag")

    css = load_css()
    # Inject the stylesheet; previously the f-string was empty and the CSS unused.
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)

    _render_sidebar()

    tab1, tab2 = st.tabs(["Assignment Documentation", "Chat"])
    with tab1:
        render_wiki_tab()
    with tab2:
        render_chat_tab()

    st.sidebar.markdown(
        """
Liran Baba | LinkedIn | HuggingFace
""",
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()