Spaces:
Sleeping
Sleeping
entertainment genres app
Browse files- app.py +60 -4
- requirements.txt +4 -1
app.py
CHANGED
@@ -1,9 +1,65 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
-
|
5 |
-
|
|
|
|
|
6 |
|
|
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
import onnxruntime as rt
|
6 |
+
import platform
|
7 |
|
8 |
|
9 |
+
# On Windows, alias pathlib.PosixPath to pathlib.WindowsPath.
# NOTE(review): presumably a workaround for loading artifacts that embed
# POSIX path objects; nothing visible in this file deserializes such
# objects, so confirm the patch is still required.
if platform.system() == "Windows":
    import pathlib

    # The original class is kept in `temp`; it is not read again in the
    # visible code.
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
|
13 |
|
14 |
+
# File name of the ONNX model loaded by the inference session below.
model_path = "entertainment-genre-quantized.onnx"

# Label map used by get_top_label(), which looks labels up by value, so the
# JSON is expected to map label -> class index. Read explicitly as UTF-8 so
# the result does not depend on the platform's default locale encoding.
with open("genre_types_encoded.json", "r", encoding="utf-8") as file:
    categories = json.load(file)

# A single inference session is created at import time and reused for every
# request; only the first input and first output of the graph are used.
inf_session = rt.InferenceSession(model_path)
input_name = inf_session.get_inputs()[0].name
output_name = inf_session.get_outputs()[0].name

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
|
24 |
+
|
25 |
+
|
26 |
+
def get_top_label(cat_dict, idx):
    """Return the first key of *cat_dict* whose value equals *idx*.

    Args:
        cat_dict: Mapping searched by value (label -> class index).
        idx: Value to look for.

    Returns:
        The first matching key in the mapping's iteration order, or ``None``
        when no value equals *idx*.
    """
    for label, value in cat_dict.items():
        if idx == value:
            return label
    # No entry carries this index.
    return None
|
30 |
+
|
31 |
+
|
32 |
+
def get_top_probs(cat_probs, idx):
    """Return the element of *cat_probs* at position *idx*.

    Args:
        cat_probs: Indexable collection of per-class scores.
        idx: Index of the class whose score is wanted.

    Returns:
        ``cat_probs[idx]``, unchanged.
    """
    return cat_probs[idx]
|
34 |
+
|
35 |
+
|
36 |
+
def entertainment_genres(description):
    """Classify *description* and return its three highest-scoring genres.

    Args:
        description: Free-text description to classify.

    Returns:
        A dict mapping genre label -> sigmoid score (``float``) for the
        three classes with the largest model outputs (fewer if the model
        emits fewer than three classes).
    """
    # Let the tokenizer do the truncation: slicing the id list afterwards
    # would cut off the closing special token on inputs longer than 512
    # tokens.
    input_ids = tokenizer(description, truncation=True, max_length=512)["input_ids"]

    # The session is fed a batch of one; element [0] of the result is the
    # first requested output, indexed below as [batch][class].
    logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
    cat_prob = torch.sigmoid(torch.FloatTensor(logits))[0]

    # Rank classes on the raw outputs; the sigmoid is only applied to the
    # values that are reported.
    scores = logits[0]
    top_3_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:3]

    return {
        get_top_label(categories, i): float(get_top_probs(cat_prob, i))
        for i in top_3_indices
    }
|
52 |
+
|
53 |
+
|
54 |
+
# Sample inputs passed to the Gradio interface as `examples`; each inner
# list holds the value for the single text input.
example = [
    ["March Of Soldiers is a real time strategy single player , It is a military game based on the player's skill and "
     "the strength of his financial economy"],
    ["When the menace known as the Joker wreaks havoc and chaos on the people of Gotham, Batman must accept one of "
     "the greatest psychological and physical tests of his ability to fight injustice."],
]
|
60 |
+
|
61 |
+
|
62 |
+
# Output widget limited to the three highest-scoring classes. `gr.Label`
# is the current component; `gr.outputs.Label` is the deprecated legacy
# namespace.
label = gr.Label(num_top_classes=3)

# Text in, label scores out, with the sample descriptions offered as
# clickable examples.
iface = gr.Interface(fn=entertainment_genres, inputs="text", outputs=label, examples=example)
iface.launch(inline=False)
|
requirements.txt
CHANGED
@@ -1 +1,4 @@
|
|
1 |
-
gradio
|
|
|
|
|
|
|
|
1 |
+
gradio==3.44.0
|
2 |
+
torch==2.0.1
|
3 |
+
transformers==4.33.1
|
4 |
+
onnxruntime==1.15.1
|