import json
import os
import sys

import gradio as gr
from decouple import config as env_config  # aliased: `config` below is the Hamilton config dict
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from hamilton import driver
from pandas import DataFrame
from pydantic import BaseModel

# Add the src directory to the Python path
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src.data_module import data_pipeline, embedding_pipeline, vectorstore  # noqa: E402
from src.classification_module import semantic_similarity, dio_support_detector  # noqa: E402
from src.enforcement_module import policy_enforcement_decider  # noqa: E402

app = FastAPI()

# CORS only matters for browser clients calling the FastAPI endpoints over
# HTTP; the Gradio callbacks below call the endpoint handlers in-process.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider restricting allow_origins to known front-ends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hamilton driver configuration. Secrets and service URLs are read with
# python-decouple's `config` helper (imported here as `env_config` so this
# dict does not shadow it).
config = {
    "loader": "pd",
    "embedding_service": "openai",
    "api_key": env_config("OPENAI_API_KEY"),
    "model_name": "text-embedding-ada-002",
    "mistral_public_url": env_config("MISTRAL_PUBLIC_URL"),
    "ner_public_url": env_config("NER_PUBLIC_URL"),
}

# Driver for the detection flow: data loading, embeddings, vector store and
# classification modules.
dr = (
    driver.Builder()
    .with_config(config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

# Separate driver for the policy-enforcement flow.
dr_enforcement = (
    driver.Builder()
    .with_config(config)
    .with_modules(policy_enforcement_decider)
    .build()
)


class RadicalizationDetectionRequest(BaseModel):
    """Request body for ``/detect_radicalization``."""

    user_text: str


class PolicyEnforcementRequest(BaseModel):
    """Request body for ``/generate_policy_enforcement``."""

    user_text: str
    # Passed through unchanged to the Hamilton enforcement flow.
    violation_context: dict


class RadicalizationDetectionResponse(BaseModel):
    """Response wrapper carrying the detection flow's output."""

    values: dict


class PolicyEnforcementResponse(BaseModel):
    """Response wrapper carrying the enforcement flow's output."""

    values: dict


def _results_to_dict(results):
    """Return Hamilton ``execute`` output as a plain dict.

    A DataFrame is converted with ``to_dict(orient="dict")``, giving
    ``{column: {index: value}}``. Any other result is returned unchanged.
    """
    if isinstance(results, DataFrame):
        return results.to_dict(orient="dict")
    return results


@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:
    """Run the detection flow and return its ``detect_glorification`` output."""
    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": ".", "user_input": request.user_text}
    )
    return RadicalizationDetectionResponse(values=_results_to_dict(results))


@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
    """Run the enforcement flow and return its ``get_enforcement_decision`` output."""
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
    )
    return PolicyEnforcementResponse(values=_results_to_dict(results))


# Gradio Interface Functions
def gradio_detect_radicalization(user_text: str):
    """Gradio callback: call the detection handler directly and return its values."""
    request = RadicalizationDetectionRequest(user_text=user_text)
    response = detect_radicalization(request)
    return response.values


def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    """Gradio callback for policy enforcement.

    ``violation_context`` must be a JSON object string. On invalid input an
    ``{"error": ...}`` dict is returned instead of raising, so the UI can
    display the problem.
    """
    try:
        context_dict = json.loads(violation_context)
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format for violation_context"}
    # Valid JSON may still be a list/number/string; the request model
    # requires a dict, so reject non-objects with the same error convention.
    if not isinstance(context_dict, dict):
        return {"error": "violation_context must be a JSON object"}
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    response = generate_policy_enforcement(request)
    return response.values


# Define the Gradio interface
iface = gr.Interface(
    fn=gradio_detect_radicalization,  # Function to detect radicalization
    inputs="text",  # Single text input
    outputs="json",  # Return JSON output
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization."
)

# Second interface for policy enforcement
iface2 = gr.Interface(
    fn=gradio_generate_policy_enforcement,  # Function to generate policy enforcement
    # Two text inputs: one for user text, one for the JSON violation context
    inputs=["text", gr.Textbox(lines=5, placeholder="Enter JSON-formatted violation context")],
    outputs="json",  # Return JSON output
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision."
)

# Combine the interfaces in a Tabbed interface
iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])

if __name__ == "__main__":
    # Only the Gradio UI is launched here; its callbacks invoke the endpoint
    # handlers in-process. The FastAPI `app` is NOT served by this command —
    # run it with an ASGI server (e.g. `uvicorn <module>:app`) if the HTTP
    # endpoints are needed.
    iface_combined.launch(server_name="0.0.0.0", server_port=7860)