|
import logging
from typing import Annotated

from decouple import config
from fastapi import FastAPI, Form, UploadFile
from hamilton import driver
from pandas import DataFrame
from pydantic import BaseModel

from classification_module import dio_support_detector, semantic_similarity
from data_module import data_pipeline, embedding_pipeline, vectorstore
from enforcement_module import policy_enforcement_decider
|
|
|
|
app = FastAPI()

# Hamilton configuration shared by both drivers below. Secrets and service
# URLs are read with python-decouple's ``config`` (environment / settings
# file); a missing key makes ``config`` raise at import time.
# Named ``pipeline_config`` (not ``config``) so it does not shadow the
# ``decouple.config`` function imported above.
# NOTE(review): the meaning of "loader": "pd" is defined by the pipeline
# modules, presumably selecting a pandas-based loader — confirm there.
pipeline_config = {
    "loader": "pd",
    "embedding_service": "openai",
    "api_key": config("OPENAI_API_KEY"),
    "model_name": "text-embedding-ada-002",
    "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
    "ner_public_url": config("NER_PUBLIC_URL"),
}


def _build_driver(*modules):
    """Build a Hamilton driver over *modules* configured with ``pipeline_config``."""
    return (
        driver.Builder()
        .with_config(pipeline_config)
        .with_modules(*modules)
        .build()
    )


# Driver over the data, embedding, vectorstore and classification modules.
dr = _build_driver(
    data_pipeline,
    embedding_pipeline,
    vectorstore,
    semantic_similarity,
    dio_support_detector,
)

# Separate driver containing only the policy-enforcement module.
dr_enforcement = _build_driver(policy_enforcement_decider)
|
|
|
|
class RadicalizationDetectionRequest(BaseModel):
    """Request body for ``POST /detect_radicalization``."""

    # Raw text handed to the detection pipeline as its ``user_input``.
    user_text: str
|
|
|
|
class PolicyEnforcementRequest(BaseModel):
    """Request body for ``POST /generate_policy_enforcement``."""

    # Forwarded to the enforcement pipeline as ``user_input``.
    user_text: str
    # Arbitrary JSON object forwarded unchanged as ``violation_context``.
    # NOTE(review): expected keys are defined by the enforcement module —
    # TODO confirm the schema there.
    violation_context: dict
|
|
|
|
class RadicalizationDetectionResponse(BaseModel):
    """Response body for the ``/detect_radicalization`` endpoint.

    ``values`` holds the detection driver's output; a pandas DataFrame result
    is converted to a plain dict before being stored here.
    """

    values: dict
|
|
|
|
class PolicyEnforcementResponse(BaseModel):
    """Body returned by the ``/generate_policy_enforcement`` endpoint.

    Attributes:
        values: Output of the enforcement driver; a DataFrame result is
            flattened to a dict before being placed here.
    """

    values: dict
|
|
|
|
logger = logging.getLogger(__name__)

# Value passed to both pipelines as the ``project_root`` input.
# NOTE(review): presumably used as a filesystem path, in which case "." is
# resolved against the server process's working directory — confirm.
PROJECT_ROOT = "."


def _to_response_values(results):
    """Return Hamilton *results* in a form suitable for a response ``values`` field.

    A pandas DataFrame (what the driver's default result builder appears to
    produce) is converted column-wise to ``{column: {index: value}}``; any
    other result is returned unchanged.
    """
    # Lazy %-formatting: the (possibly large) results are only rendered
    # when DEBUG logging is enabled, instead of always printing to stdout.
    logger.debug("Hamilton results (%s): %s", type(results).__name__, results)
    if isinstance(results, DataFrame):
        return results.to_dict(orient="dict")
    return results


@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest,
) -> RadicalizationDetectionResponse:
    """Run the detection driver on ``request.user_text``.

    Requests the ``detect_glorification`` output and returns it wrapped in a
    ``RadicalizationDetectionResponse``.
    """
    results = dr.execute(
        final_vars=["detect_glorification"],
        inputs={"project_root": PROJECT_ROOT, "user_input": request.user_text},
    )
    return RadicalizationDetectionResponse(values=_to_response_values(results))


@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest,
) -> PolicyEnforcementResponse:
    """Run the enforcement driver on the user text and its violation context.

    Requests the ``get_enforcement_decision`` output and returns it wrapped in
    a ``PolicyEnforcementResponse``.
    """
    results = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"],
        inputs={
            "project_root": PROJECT_ROOT,
            "user_input": request.user_text,
            "violation_context": request.violation_context,
        },
    )
    return PolicyEnforcementResponse(values=_to_response_values(results))
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Listen on all interfaces, port 8000.
    server_host, server_port = "0.0.0.0", 8000
    uvicorn.run(app, host=server_host, port=server_port)