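# NOTE: "Spaces: Sleeping" is page-status text captured together with this
# file when it was copied from its hosting page; it is not part of the program.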
import asyncio
import base64
import html
import sys
from pathlib import Path
from typing import Literal, Tuple, Optional

import pandas as pd
import streamlit as st

import graphrag.api as api
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging import PrintProgressReporter
from graphrag.utils.storage import _create_storage, _load_table_from_storage

import auth
from search_handlers import run_global_search, run_local_search, run_drift_search
from wiki import render_wiki_tab
# Configure the browser tab title, icon and wide layout before anything else
# in this module draws to the page.
# NOTE(review): the page_icon glyph looks mis-encoded in this copy of the file
# and the intended emoji cannot be recovered from it -- confirm and restore.
st.set_page_config(page_title="GraphRAG Chat Interface", page_icon="π", layout="wide")

# Default avatars, defined at module level.  They are written as Unicode
# escapes so the glyphs survive the file being re-saved in another encoding
# (the previous literals had degraded into unreadable characters).
DEFAULT_USER_AVATAR = "\U0001F464"  # bust-in-silhouette emoji
DEFAULT_BOT_AVATAR = "\U0001F916"  # robot-face emoji

# Seed the per-session avatars only when they are missing, so a value already
# stored in the session is left untouched.
if "user_avatar" not in st.session_state:
    st.session_state.user_avatar = DEFAULT_USER_AVATAR
if "bot_avatar" not in st.session_state:
    st.session_state.bot_avatar = DEFAULT_BOT_AVATAR

# Older names for the same glyphs; kept as aliases so any code that still
# refers to them sees exactly the defaults above.
USER_AVATAR = DEFAULT_USER_AVATAR
BOT_AVATAR = DEFAULT_BOT_AVATAR
class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that sends success messages to a Streamlit placeholder."""

    def __init__(self, placeholder) -> None:
        """Store *placeholder* for later use.

        The parent reporter is initialised with an empty string.
        """
        super().__init__("")
        self.placeholder = placeholder

    def success(self, message: str) -> None:
        """Forward *message* to the placeholder's ``success`` method."""
        self.placeholder.success(message)
def _append_chat_message(role: str, content: str) -> None:
    """Append a message to the session history, stamped with the current HH:MM."""
    st.session_state.messages.append(
        {
            "role": role,
            "content": content,
            "timestamp": pd.Timestamp.now().strftime("%H:%M"),
        }
    )


def render_chat_tab():
    """Render the Chat tab content.

    Draws the stored message history, the search context of the latest
    successful answer (if any), and the chat input.  When the user submits a
    query, the search function matching ``st.session_state.search_type`` is
    run, its answer (or an error message) is appended to the history, and the
    script is rerun so the new messages are drawn as part of the history.
    """
    format_message_history()

    # The context is kept in session state and drawn here, on the run that
    # follows the query.  Drawing it in the same run as the search would be
    # pointless: st.rerun() below restarts the script right afterwards.
    last_context = st.session_state.get("last_search_context")
    if last_context is not None:
        with st.expander("View Search Context"):
            st.json(last_context)

    # Chat input
    prompt = st.chat_input("Enter your query...")
    if not prompt:
        return

    _append_chat_message("user", prompt)

    # "global" and "drift" have dedicated runners; every other search type
    # falls back to the local search.
    runners = {"global": run_global_search, "drift": run_drift_search}
    run_search = runners.get(st.session_state.search_type, run_local_search)

    with st.spinner("Processing your query..."):
        response_placeholder = st.empty()
        try:
            response, context = run_search(
                config_filepath=st.session_state.config_filepath,
                data_dir=st.session_state.data_dir,
                root_dir=st.session_state.root_dir,
                community_level=st.session_state.community_level,
                response_type=st.session_state.response_type,
                streaming=st.session_state.streaming,
                query=prompt,
                progress_placeholder=response_placeholder,
            )
        except Exception as e:
            # Report the failure inside the conversation instead of crashing
            # the page, and drop any context left over from an earlier answer.
            st.session_state.last_search_context = None
            _append_chat_message("assistant", f"Error processing query: {e}")
        else:
            # Clear the placeholder before adding the final response
            response_placeholder.empty()
            _append_chat_message("assistant", response)
            st.session_state.last_search_context = context

    # Redraw so the messages appended above show up in the history.
    st.rerun()
def display_message(msg: str, is_user: bool = False, timestamp: str = "") -> None:
    """Display a chat message with avatar and consistent formatting.

    Args:
        msg: Message text to show inside the bubble.
        is_user: True for a user message, False for an assistant message;
            selects the CSS class and the avatar taken from session state.
        timestamp: Text shown under the bubble (may be empty).
    """
    if is_user:
        message_class = "user-message"
        avatar = st.session_state.user_avatar
        # User text is typed by the visitor and is emitted below with
        # unsafe_allow_html=True, so escape it to keep it from being
        # interpreted as markup.
        body = html.escape(str(msg))
    else:
        message_class = "assistant-message"
        avatar = st.session_state.bot_avatar
        # NOTE(review): assistant text is still inserted unescaped, as before,
        # so any formatting it carries is unchanged -- confirm whether it
        # should be sanitised as well.
        body = msg

    message_container = (
        "\n"
        f'<div class="chat-message {message_class}">\n'
        '<div class="avatar">\n'
        f'<div style="font-size: 25px; text-align: center;">{avatar}</div>\n'
        "</div>\n"
        '<div class="message-content-wrapper">\n'
        '<div class="message-bubble">\n'
        '<div class="message-content">\n'
        f"{body}\n"
        "</div>\n"
        "</div>\n"
        f'<div class="timestamp">{timestamp}</div>\n'
        "</div>\n"
        "</div>\n"
    )
    st.markdown(message_container, unsafe_allow_html=True)
def format_message_history() -> None:
    """Draw every stored chat message, in order, between the chat-container tags."""
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)

    for entry in st.session_state.messages:
        # Messages without a stored timestamp are shown with an empty one.
        stamp = entry.get("timestamp", "")
        text = entry["content"]
        from_user = entry["role"] == "user"
        display_message(msg=text, is_user=from_user, timestamp=stamp)

    st.markdown("</div>", unsafe_allow_html=True)
def load_css() -> str:
    """Return the contents of ``styles.css`` from the current working directory.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default encoding.

    Raises:
        FileNotFoundError: If ``styles.css`` does not exist.
    """
    return Path("styles.css").read_text(encoding="utf-8")
def initialize_session_state():
    """Give every session-state key used by the app a default, unless already set."""
    defaults = {
        "messages": [],
        "response_placeholder": None,
        "config_filepath": None,
        "data_dir": None,
        "root_dir": ".",
        "community_level": 2,
        "response_type": "concise",
        "search_type": "global",
        "streaming": True,
        "authenticated": False,
    }
    for key, value in defaults.items():
        # Existing values are kept; only missing keys receive their default.
        if key not in st.session_state:
            st.session_state[key] = value
def _render_sidebar() -> None:
    """Draw the sidebar: logos, search configuration widgets and the logout button."""
    with st.sidebar:
        # Display logos side by side at the top of the sidebar
        col1, col2 = st.columns(2)
        with col1:
            st.markdown(
                '<div class="logo-container"><img class="logo-image" src="https://nexttech.pwc.co.il/wp-content/uploads/2023/12/image-2.png"></div>',
                unsafe_allow_html=True,
            )
        with col2:
            st.markdown(
                '<div class="logo-container"><img class="logo-image" src="https://nexttech.pwc.co.il/wp-content/uploads/2023/12/Frame.png"></div>',
                unsafe_allow_html=True,
            )

        st.header("Configuration")
        st.session_state.community_level = st.number_input(
            "Community Level",
            min_value=0,
            max_value=10,
            value=st.session_state.community_level,
            help="Controls the granularity of the search...",
        )

        # Only show response type for global and local search
        if st.session_state.search_type != "drift":
            st.session_state.response_type = st.selectbox(
                "Response Type",
                options=["concise", "detailed"],
                index=0 if st.session_state.response_type == "concise" else 1,
                help="Style of response generation",
            )

        st.session_state.search_type = st.selectbox(
            "Search Type",
            options=["global", "local", "drift"],
            index=(
                0
                if st.session_state.search_type == "global"
                else 1 if st.session_state.search_type == "local" else 2
            ),
            help=(
                "Search Types:\n"
                '- Local Search: "Focuses on finding specific information by searching through direct connections in the knowledge graph. Best for precise, fact-based queries."\n'
                '- Global Search: "Analyzes the entire document collection at a high level using community summaries. Best for understanding broad themes and general policies."\n'
                '- DRIFT Search: "Combines local and global search capabilities, dynamically exploring connections while gathering detailed information. Best for complex queries requiring both specific details and broader context."\n'
            ),
        )

        # Show streaming option only for supported search types
        if st.session_state.search_type != "drift":
            st.session_state.streaming = st.checkbox(
                "Enable Streaming",
                value=st.session_state.streaming,
                help="Stream response tokens as they're generated",
            )
        else:
            st.session_state.streaming = False
            st.info("Streaming is not available for DRIFT search")

        # Logout button
        if st.button("Logout"):
            st.session_state.clear()  # Clear all session state data
            initialize_session_state()  # Reinitialize the session state
            # Set the parameter through the query-params object.  Assigning a
            # dict to ``st.query_params`` would only rebind the attribute on
            # the streamlit module instead of changing the URL parameters.
            st.query_params["restart"] = "true"
            st.rerun()


def main():
    """Run the app: authenticate the visitor, then draw the sidebar, tabs and footer."""
    initialize_session_state()

    # Authentication check
    if not st.session_state.authenticated:
        if auth.check_credentials():
            st.session_state.authenticated = True
            st.rerun()  # Rerun to reflect the authentication state
        else:
            st.stop()  # Stop further execution if authentication fails

    # Never draw the application for an unauthenticated session.
    if not st.session_state.authenticated:
        return

    # Main application content
    st.title("PWC Home Assignment #1, Graphrag")
    css = load_css()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)

    # Sidebar configuration
    _render_sidebar()

    # Create tabs
    tab1, tab2 = st.tabs(["Assignment Documentation", "Chat"])
    # Readme tab content
    with tab1:
        render_wiki_tab()
    # Chat tab content
    with tab2:
        render_chat_tab()

    # Author footer at the bottom of the sidebar
    st.sidebar.markdown(
        "\n"
        '<div style="position: absolute; bottom: 0; width: 100%; text-align: center; font-size: 14px; margin-bottom: -200px;">\n'
        "Liran Baba |\n"
        '<a href="https://linkedin.com/in/liranba" target="_blank">LinkedIn</a> |\n'
        '<a href="https://huggingface.co/CordwainerSmith" target="_blank">HuggingFace</a>\n'
        "</div>\n",
        unsafe_allow_html=True,
    )
# Script entry point: start the app only when this module is run as the main
# program, not when it is imported.
if __name__ == "__main__":
    main()