import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the fine-tuned Turkish news classifier and its tokenizer from the Hub.
model_name = "TerminatorPower/bert-news-classif-turkish"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

# Select the device once and move the model to it here, rather than on every call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

reverse_label_mapping = {
    0: "turkiye",
    1: "ekonomi",
    2: "dunya",
    3: "spor",
    4: "magazin",
    5: "guncel",
    6: "genel",
    7: "siyaset",
    8: "saglik",
    9: "kultur-sanat",
    10: "teknoloji",
    11: "yasam",
}

def predict(text):
    # Tokenize the input, truncating/padding to the model's 512-token limit.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # Pick the highest-scoring class index and map it back to its category name.
    predictions = torch.argmax(outputs.logits, dim=1)
    predicted_label = reverse_label_mapping[predictions.item()]
    return predicted_label


if __name__ == "__main__":
    text = input()
    print(f"Predicted label: {predict(text)}")
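
# Optional extension, not part of the original snippet: a minimal sketch of how the
# same model could return the full probability distribution instead of only the
# argmax label. The helper name `predict_proba` is hypothetical; it reuses the
# tokenizer, model, device, and reverse_label_mapping defined above.
def predict_proba(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax converts the raw logits into probabilities over the 12 categories.
    probs = torch.softmax(logits, dim=1).squeeze(0)
    return {reverse_label_mapping[i]: probs[i].item() for i in range(probs.shape[0])}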