# predict.py — uploaded by TerminatorPower
# (commit fd49f40: "Update predict.py", 1.17 kB)
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hub identifier of the fine-tuned BERT checkpoint used for classification.
model_name = "TerminatorPower/bert-news-classif-turkish"

# Tokenizer and sequence-classification model for that checkpoint. The model is
# switched to evaluation mode right away (Module.eval() returns the module
# itself) because this script only runs inference.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
# Load the reverse label mapping
reverse_label_mapping = {
0: "turkiye",
1: "ekonomi",
2: "dunya",
3: "spor",
4: "magazin",
5: "guncel",
6: "genel",
7: "siyaset",
8: "saglik",
9: "kultur-sanat",
10: "teknoloji",
11: "yasam"
}
def predict(text):
    """Predict the news category of the given text.

    Args:
        text: A single string, or a list of strings to classify in one batch.

    Returns:
        A category name from ``reverse_label_mapping`` (e.g. ``"spor"``) when
        the input yields exactly one prediction (a single string or a
        one-element list, as before). Otherwise, a list of category names,
        one per input text, in input order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep the model on the same device as its inputs; once the model already
    # lives there this call changes nothing.
    model.to(device)

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits

    # argmax over the last (class) dimension gives one class index per input
    # row. ``.tolist()``, unlike ``.item()``, also works for batches of several
    # texts instead of raising.
    predicted_ids = logits.argmax(dim=-1).tolist()
    labels = [reverse_label_mapping[idx] for idx in predicted_ids]
    return labels[0] if len(labels) == 1 else labels
if __name__ == "__main__":
    # Read a single line from standard input and report its predicted category.
    news_text = input()
    category = predict(news_text)
    print(f"Predicted label: {category}")