# Hugging Face Space page banner captured with this file: "Spaces: Sleeping".
from typing import List

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import CrossEncoder, SentenceTransformer
from sentence_transformers.util import cos_sim
# FastAPI application instance (the ASGI app object for this module).
app = FastAPI()
class InputListModel(BaseModel):
    """Request body holding a list of keywords and a list of contents."""

    # NOTE(review): predict_list compares these element-wise (cosine along
    # dim=-1), so callers presumably send lists of matching length.
    keywords: List[str]
    contents: List[str]
class InputModel(BaseModel):
    """Request body holding a single keyword and a single content string."""

    keyword: str
    content: str
# model = CrossEncoder( | |
# # "jinaai/jina-reranker-v2-base-multilingual", | |
# "Alibaba-NLP/gte-multilingual-reranker-base", | |
# trust_remote_code=True, | |
# ) | |
# Bi-encoder that predict / predict_list use to embed keywords and contents.
# Loaded once at import time.
# trust_remote_code is intentionally not set. This checkpoint uses a stock
# transformer architecture with no custom modelling code, so the flag only
# widened what the model repo could execute at startup.
model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
# NOTE(review): the "/" path is assumed (FastAPI template default); the
# undecorated original registered no route at all.
@app.get("/")
def greet_json():
    """Return the fixed payload ``{"Hello": "World!"}``."""
    return {"Hello": "World!"}
@app.post("/predict")
async def predict(inp: InputModel):
    """Score how similar ``inp.content`` is to ``inp.keyword``.

    Both strings are embedded with the module-level ``model``. The cosine
    similarity of the two embeddings is rescaled from [-1, 1] to [0, 1]
    via ``(cos + 1) / 2``.

    Args:
        inp: Request body with one ``keyword`` and one ``content``.

    Returns:
        ``{"results": score}``. For single-string input ``model.encode``
        presumably returns one vector, so ``score`` is a plain float.
    """
    # NOTE(review): model.encode is a synchronous call made inside an async
    # handler, so it occupies the event loop while it runs.
    content_emb = model.encode(inp.content, convert_to_tensor=True)
    keyword_emb = model.encode(inp.keyword, convert_to_tensor=True)
    score = (torch.nn.functional.cosine_similarity(content_emb, keyword_emb, dim=-1) + 1) / 2
    return {"results": score.tolist()}
@app.post("/predict_list")
async def predict_list(inp: InputListModel):
    """Score each content against its keyword, element by element.

    ``inp.contents[i]`` is compared with ``inp.keywords[i]``; this is not an
    all-pairs matrix. ``cosine_similarity`` broadcasts, so a single keyword
    (or a single content) is compared against every item on the other side.
    Each score is the cosine similarity of the two embeddings, rescaled from
    [-1, 1] to [0, 1] via ``(cos + 1) / 2``.

    Args:
        inp: Request body with ``keywords`` and ``contents`` lists.

    Returns:
        ``{"results": [...]}`` with one score per pair, or
        ``{"results": []}`` when both lists are empty.

    Raises:
        HTTPException: 422 when the lists cannot be paired, i.e. their
            lengths differ and neither list has exactly one item.
    """
    n_contents = len(inp.contents)
    n_keywords = len(inp.keywords)

    # Nothing to score: skip the model so the response is always a list.
    if n_contents == 0 and n_keywords == 0:
        return {"results": []}

    # Equal lengths pair row by row, and a length-1 side broadcasts. Any other
    # combination cannot broadcast and used to escape as an unhandled error
    # (HTTP 500), so reject it up front with a clear client error.
    if n_contents != n_keywords and min(n_contents, n_keywords) != 1:
        raise HTTPException(
            status_code=422,
            detail=(
                f"contents ({n_contents}) and keywords ({n_keywords}) must have "
                "the same length, or one of them must contain exactly one item."
            ),
        )

    # NOTE(review): model.encode is a synchronous call made inside an async
    # handler, so it occupies the event loop while it runs.
    content_emb = model.encode(inp.contents, convert_to_tensor=True)
    keyword_emb = model.encode(inp.keywords, convert_to_tensor=True)
    scores = (torch.nn.functional.cosine_similarity(content_emb, keyword_emb, dim=-1) + 1) / 2
    return {"results": scores.tolist()}
# @app.post("/predict_list") | |
# async def predict_list(inp : InputListModel): | |
# sentence_pairs = [[query, doc] for query,doc in zip(inp.keywords, inp.contents)] | |
# scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist() | |
# # (-scores).argsort().tolist() | |
# return {"results":scores.tolist()} | |
# @app.post("/predict") | |
# async def predict(inp : InputModel): | |
# sentence_pairs = [[inp.keyword, inp.content]] | |
# scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist() | |
# # (-scores).argsort().tolist() | |
# return {"results":scores.tolist()[0]} | |
# keywords = model.encode(inp.keywords) | |
# contents = model.encode(inp.contents) | |
# return {"results":np.linalg.norm(contents-keywords).tolist()} | |