import json
import os
import sys

import gradio as gr
from decouple import config
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from hamilton import driver
from pandas import DataFrame
from pydantic import BaseModel

# Add the local `src` directory to the module search path.
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from src.data_module import data_pipeline, embedding_pipeline, vectorstore
from src.classification_module import semantic_similarity, dio_support_detector
from src.enforcement_module import policy_enforcement_decider

app = FastAPI()

# Allow cross-origin requests so the API can be called from browser front ends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
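
# Configuration values are read from the environment (e.g. a .env file) via
# python-decouple. Required variables:
#   OPENAI_API_KEY     - API key for the OpenAI embedding service
#   MISTRAL_PUBLIC_URL - URL of the Mistral service used by the pipelines
#   NER_PUBLIC_URL     - URL of the NER service used by the pipelines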
pipeline_config = {
    "loader": "pd",
    "embedding_service": "openai",
    "api_key": config("OPENAI_API_KEY"),
    "model_name": "text-embedding-ada-002",
    "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
    "ner_public_url": config("NER_PUBLIC_URL"),
}

# Hamilton driver for the detection flow (data loading, embeddings, vector store,
# similarity and support-detection modules).
dr = (
    driver.Builder()
    .with_config(pipeline_config)
    .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
    .build()
)

# Separate Hamilton driver for the policy-enforcement flow.
dr_enforcement = (
    driver.Builder()
    .with_config(pipeline_config)
    .with_modules(policy_enforcement_decider)
    .build()
)
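
# The `final_vars` requested from these drivers below (`detect_glorification`,
# `get_enforcement_decision`) are assumed to be node functions defined in the
# Hamilton modules wired in above.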


class RadicalizationDetectionRequest(BaseModel):
    user_text: str


class PolicyEnforcementRequest(BaseModel):
    user_text: str
    violation_context: dict


class RadicalizationDetectionResponse(BaseModel):
    values: dict


class PolicyEnforcementResponse(BaseModel):
    values: dict
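
# Both response models wrap the raw results dict returned by a Hamilton driver,
# keyed by the requested output names (e.g. "detect_glorification").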
@app.post("/detect_radicalization") |
|
def detect_radicalization( |
|
request: RadicalizationDetectionRequest |
|
) -> RadicalizationDetectionResponse: |
|
|
|
results = dr.execute( |
|
final_vars=["detect_glorification"], |
|
inputs={"project_root": ".", "user_input": request.user_text} |
|
) |
|
if isinstance(results, DataFrame): |
|
results = results.to_dict(orient="dict") |
|
return RadicalizationDetectionResponse(values=results) |
|
|
|
@app.post("/generate_policy_enforcement") |
|
def generate_policy_enforcement( |
|
request: PolicyEnforcementRequest |
|
) -> PolicyEnforcementResponse: |
|
results = dr_enforcement.execute( |
|
final_vars=["get_enforcement_decision"], |
|
inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context} |
|
) |
|
if isinstance(results, DataFrame): |
|
results = results.to_dict(orient="dict") |
|
return PolicyEnforcementResponse(values=results) |
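
# Example requests, assuming the FastAPI app is served separately with uvicorn
# (the module name below is a placeholder for this file's module):
#
#   uvicorn <module_name>:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/detect_radicalization \
#        -H "Content-Type: application/json" \
#        -d '{"user_text": "text to analyse"}'
#
#   curl -X POST http://localhost:8000/generate_policy_enforcement \
#        -H "Content-Type: application/json" \
#        -d '{"user_text": "text to analyse", "violation_context": {"...": "..."}}'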


def gradio_detect_radicalization(user_text: str):
    """Gradio wrapper around the /detect_radicalization endpoint logic."""
    request = RadicalizationDetectionRequest(user_text=user_text)
    response = detect_radicalization(request)
    return response.values


def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    """Gradio wrapper that parses the JSON context string before calling the endpoint logic."""
    try:
        context_dict = json.loads(violation_context)
    except json.JSONDecodeError:
        return {"error": "Invalid JSON format for violation_context"}
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    response = generate_policy_enforcement(request)
    return response.values


# Gradio UIs for interactively exercising the two flows.
iface = gr.Interface(
    fn=gradio_detect_radicalization,
    inputs="text",
    outputs="json",
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization."
)

iface2 = gr.Interface(
    fn=gradio_generate_policy_enforcement,
    inputs=["text", gr.Textbox(lines=5, placeholder="Enter JSON-formatted violation context")],
    outputs="json",
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision."
)

iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])

if __name__ == "__main__":
    iface_combined.launch(server_name="0.0.0.0", server_port=7860)
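
# Note: instead of launching Gradio standalone, the UI could also be mounted onto
# the FastAPI app (a sketch, not what this script does):
#     app = gr.mount_gradio_app(app, iface_combined, path="/ui")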