tsgpt / src /server.py
brichett's picture
upload src folder
854f61d verified
raw
history blame
2.67 kB
from typing import Annotated
from fastapi import FastAPI, Form, UploadFile
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame
from data_module import data_pipeline, embedding_pipeline, vectorstore
from classification_module import semantic_similarity, dio_support_detector
from enforcement_module import policy_enforcement_decider
from decouple import config
app = FastAPI()
config = {"loader": "pd",
"embedding_service": "openai",
"api_key": config("OPENAI_API_KEY"),
"model_name": "text-embedding-ada-002",
"mistral_public_url": config("MISTRAL_PUBLIC_URL"),
"ner_public_url": config("NER_PUBLIC_URL")
} # or "pd"
dr = (
driver.Builder()
.with_config(config)
.with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
.build()
)
dr_enforcement = (
driver.Builder()
.with_config(config)
.with_modules(policy_enforcement_decider)
.build()
)
class RadicalizationDetectionRequest(BaseModel):
user_text: str
class PolicyEnforcementRequest(BaseModel):
user_text: str
violation_context: dict
class RadicalizationDetectionResponse(BaseModel):
"""Response to the /detect endpoint"""
values: dict
class PolicyEnforcementResponse(BaseModel):
"""Response to the /generate_policy_enforcement endpoint"""
values: dict
@app.post("/detect_radicalization")
def detect_radicalization(
request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:
results = dr.execute(
final_vars=["detect_glorification"],
inputs={"project_root": ".", "user_input": request.user_text}
)
print(results)
print(type(results))
if isinstance(results, DataFrame):
results = results.to_dict(orient="dict")
return RadicalizationDetectionResponse(values=results)
@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
results = dr_enforcement.execute(
final_vars=["get_enforcement_decision"],
inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
)
print(results)
print(type(results))
if isinstance(results, DataFrame):
results = results.to_dict(orient="dict")
return PolicyEnforcementResponse(values=results)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) # specify host and port