TerminatorPower
commited on
Update predict.py
Browse files- predict.py +13 -13
predict.py
CHANGED
@@ -9,21 +9,21 @@ model.eval()
|
|
9 |
|
10 |
# Load the reverse label mapping
|
11 |
reverse_label_mapping = {
|
12 |
-
0: "
|
13 |
-
1: "
|
14 |
-
2: "
|
15 |
-
3: "
|
16 |
-
4: "
|
17 |
-
5: "
|
18 |
-
6: "
|
19 |
-
7: "
|
20 |
-
8: "
|
21 |
-
9: "
|
22 |
-
10: "
|
23 |
-
11: "
|
24 |
-
12: "siyaset" # Example: Map index 12 back to "siyaset"
|
25 |
}
|
26 |
|
|
|
27 |
def predict(text):
|
28 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
|
29 |
inputs = {key: value.to("cuda" if torch.cuda.is_available() else "cpu") for key, value in inputs.items()}
|
|
|
9 |
|
10 |
# Load the reverse label mapping
|
11 |
reverse_label_mapping = {
|
12 |
+
0: "turkiye",
|
13 |
+
1: "ekonomi",
|
14 |
+
2: "dunya",
|
15 |
+
3: "spor",
|
16 |
+
4: "magazin",
|
17 |
+
5: "guncel",
|
18 |
+
6: "genel",
|
19 |
+
7: "siyaset",
|
20 |
+
8: "saglik",
|
21 |
+
9: "kultur-sanat",
|
22 |
+
10: "teknoloji",
|
23 |
+
11: "yasam"
|
|
|
24 |
}
|
25 |
|
26 |
+
|
27 |
def predict(text):
|
28 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
|
29 |
inputs = {key: value.to("cuda" if torch.cuda.is_available() else "cpu") for key, value in inputs.items()}
|