jonathanjordan21 commited on
Commit
b1a4b26
·
verified ·
1 Parent(s): 197e027

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -16
app.py CHANGED
@@ -1,6 +1,8 @@
1
  from fastapi import FastAPI
2
  import numpy as np
3
- from sentence_transformers import CrossEncoder
 
 
4
  from typing import List
5
  from pydantic import BaseModel
6
 
@@ -15,30 +17,56 @@ class InputModel(BaseModel):
15
  content: str
16
 
17
 
18
- model = CrossEncoder(
19
- # "jinaai/jina-reranker-v2-base-multilingual",
20
- "Alibaba-NLP/gte-multilingual-reranker-base",
21
- trust_remote_code=True,
 
 
 
 
 
22
  )
23
 
24
  @app.get("/")
25
  def greet_json():
26
  return {"Hello": "World!"}
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  @app.post("/predict_list")
29
- async def predict_list(inp : InputListModel):
30
- sentence_pairs = [[query, doc] for query,doc in zip(inp.keywords, inp.contents)]
31
- scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
32
- # (-scores).argsort().tolist()
33
- return {"results":scores.tolist()}
34
 
 
 
 
35
 
36
- @app.post("/predict")
37
- async def predict(inp : InputModel):
38
- sentence_pairs = [[inp.keyword, inp.content]]
39
- scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
40
- # (-scores).argsort().tolist()
41
- return {"results":scores.tolist()[0]}
 
 
 
 
 
 
 
 
42
 
43
  # keywords = model.encode(inp.keywords)
44
  # contents = model.encode(inp.contents)
 
1
  from fastapi import FastAPI
2
  import numpy as np
3
+ from sentence_transformers import CrossEncoder, SentenceTransformer
4
+ from sentence_transformers.util import cos_sim
5
+
6
  from typing import List
7
  from pydantic import BaseModel
8
 
 
17
  content: str
18
 
19
 
20
+ # model = CrossEncoder(
21
+ # # "jinaai/jina-reranker-v2-base-multilingual",
22
+ # "Alibaba-NLP/gte-multilingual-reranker-base",
23
+ # trust_remote_code=True,
24
+ # )
25
+
26
+ model = SentenceTransformer(
27
+ "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
28
+ trust_remote_code=True
29
  )
30
 
31
  @app.get("/")
32
  def greet_json():
33
  return {"Hello": "World!"}
34
 
35
+ @app.post("/predict")
36
+ async def predict(inp: InputModel):
37
+
38
+ text_emb = model.encode(inp.contents, convert_to_tensor=True)
39
+
40
+ summarize = model.encode(inp.keywords, convert_to_tensor=True)
41
+
42
+ out = (torch.nn.functional.cosine_similarity(text_emb, summarize, dim=-1) + 1)/2
43
+ # out = (cos_sim(text_emb, summarize) + 1)/2
44
+ return {"results":out.tolist()}
45
+
46
+
47
  @app.post("/predict_list")
48
+ async def predict(inp: InputListModel):
49
+ text_emb = model.encode(inp.contents, convert_to_tensor=True)
50
+ summarize = model.encode(inp.keywords, convert_to_tensor=True)
 
 
51
 
52
+ out = (torch.nn.functional.cosine_similarity(text_emb, summarize, dim=-1) + 1)/2
53
+ # out = (cos_sim(text_emb, summarize) + 1)/2
54
+ return {"results":out.tolist()}
55
 
56
+ # @app.post("/predict_list")
57
+ # async def predict_list(inp : InputListModel):
58
+ # sentence_pairs = [[query, doc] for query,doc in zip(inp.keywords, inp.contents)]
59
+ # scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
60
+ # # (-scores).argsort().tolist()
61
+ # return {"results":scores.tolist()}
62
+
63
+
64
+ # @app.post("/predict")
65
+ # async def predict(inp : InputModel):
66
+ # sentence_pairs = [[inp.keyword, inp.content]]
67
+ # scores = model.predict(sentence_pairs, convert_to_tensor=False)#.tolist()
68
+ # # (-scores).argsort().tolist()
69
+ # return {"results":scores.tolist()[0]}
70
 
71
  # keywords = model.encode(inp.keywords)
72
  # contents = model.encode(inp.contents)