|
import uvicorn |
|
import click |
|
import argparse |
|
import json |
|
from importlib import import_module |
|
from fastapi import FastAPI, WebSocket |
|
from starlette.middleware.cors import CORSMiddleware |
|
from utils import user_config, api_logger, setup_logger, RequestDataStructure |
|
|
|
|
|
|
|
# Command-line interface: a single positional argument, the path to the
# server configuration.
total_parser = argparse.ArgumentParser(prog="API")
total_parser.add_argument("config_path", type=str)
args = total_parser.parse_args()

# Load the configuration from the parsed CLI arguments first; the logger is
# configured afterwards because setup_logger receives the (now populated)
# config object — presumably it reads logging settings from it, verify in utils.
user_config.setup_config(args)
setup_logger(api_logger, user_config)
|
|
|
|
|
# Resolve the pipeline implementation dynamically: the configured
# ``pipeline_type`` names a module under ``fengshen.pipelines`` that must
# expose a ``Pipeline`` class.
_pipeline_module = import_module("fengshen.pipelines." + user_config.pipeline_type)
pipeline_class = _pipeline_module.Pipeline

# The model settings mapping is exposed to the pipeline as attribute-style
# arguments, mirroring what an argparse-based caller would pass.
model_settings = user_config.model_settings
model_args = argparse.Namespace(**model_settings)

# Single shared pipeline instance used by the request handler(s).
pipeline = pipeline_class(args=model_args, model=user_config.model_name)
|
|
|
|
|
|
|
# The OpenAPI schema is served under the configured API prefix.
_openapi_url = f"{user_config.API_PREFIX_STR}/openapi.json"

app = FastAPI(
    title=user_config.PROJECT_NAME,
    openapi_url=_openapi_url,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the inference endpoint. Only POST is supported; the configured
# method is compared case-insensitively so e.g. "post" in the config works too.
if str(user_config.API_method).upper() == "POST":

    @app.post(user_config.API_path, tags=user_config.API_tags)
    async def fengshen_post(data: RequestDataStructure):
        """Run the configured pipeline on ``data.input_text`` and return its result.

        FastAPI parses the request body into ``RequestDataStructure``
        (presumably a Pydantic model defined in ``utils`` — verify there).
        The pipeline's return value is handed back to FastAPI unchanged for
        response serialization.
        """
        input_text = data.input_text
        # Lazy %-style argument: the text is only formatted if INFO is enabled.
        api_logger.info("%s", input_text)
        # NOTE(review): ``pipeline`` is called synchronously inside an
        # ``async def`` handler, so a long-running inference blocks the event
        # loop (and every other request) for its duration. Confirm the
        # pipeline is thread-safe before offloading it to a threadpool.
        return pipeline(input_text)

else:
    # Without this route the server would start but expose no inference
    # endpoint, so fail fast on an unsupported method instead of printing
    # a message to stdout and serving nothing.
    raise ValueError(
        f"only support POST method, got API_method={user_config.API_method!r}"
    )
|
|
|
|
|
|
|
|
|
# Enable CORS only when at least one allowed origin is configured.
if user_config.BACKEND_CORS_ORIGINS:
    # Origins are stringified because they may be URL-like objects rather
    # than plain strings in the config (assumption — confirm in utils).
    _cors_origins = list(map(str, user_config.BACKEND_CORS_ORIGINS))
    app.add_middleware(
        CORSMiddleware,
        allow_origins=_cors_origins,
        allow_credentials=user_config.allow_credentials,
        allow_methods=user_config.allow_methods,
        allow_headers=user_config.allow_headers,
    )
|
|
|
|
|
def _run_server() -> None:
    """Serve ``app`` with uvicorn on the host and port from the loaded config."""
    uvicorn.run(
        app,
        host=user_config.SERVER_HOST,
        port=user_config.SERVER_PORT,
    )


if __name__ == "__main__":
    _run_server()
|
|
|
|
|
|