Spaces:
Runtime error
Runtime error
File size: 1,679 Bytes
d6a25c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
from typing import Dict
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
TextClassificationPipeline,
)
class NewsPipeline:
    """Run three text-classification models over a news article.

    The category and fake models score the headline and body joined by each
    model's own separator token. The clickbait model scores the headline
    alone. Each result is the label of the top-scoring class.
    """

    def __init__(
        self,
        category_model: str = "elozano/news-category",
        fake_model: str = "elozano/news-fake",
        clickbait_model: str = "elozano/news-clickbait",
    ) -> None:
        """Load the tokenizers and models and wrap each pair in a pipeline.

        Args:
            category_model: Hub id or local path of the category classifier.
            fake_model: Hub id or local path of the fake-news classifier.
            clickbait_model: Hub id or local path of the clickbait classifier.
        """
        # The category and fake tokenizers are kept on the instance because
        # __call__ needs their ``sep_token`` to build the two-part input.
        self.category_tokenizer = AutoTokenizer.from_pretrained(category_model)
        self.category_pipeline = self._build_pipeline(
            category_model, self.category_tokenizer
        )
        self.fake_tokenizer = AutoTokenizer.from_pretrained(fake_model)
        self.fake_pipeline = self._build_pipeline(fake_model, self.fake_tokenizer)
        self.clickbait_pipeline = self._build_pipeline(clickbait_model)

    @staticmethod
    def _build_pipeline(model_name: str, tokenizer=None):
        """Return a TextClassificationPipeline for ``model_name``.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            tokenizer: An already-loaded tokenizer to reuse. When None, one is
                loaded from ``model_name``.
        """
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        return TextClassificationPipeline(
            model=AutoModelForSequenceClassification.from_pretrained(model_name),
            tokenizer=tokenizer,
        )

    @staticmethod
    def _join_article(sep_token: str, headline: str, content: str) -> str:
        """Join headline and content as ``"<headline> <sep> <content>"``."""
        return f" {sep_token} ".join([headline, content])

    @staticmethod
    def _top_label(classifier, text: str) -> str:
        """Return the label of the first (top) prediction for ``text``."""
        # truncation=True: tokenizers do not truncate by default, so an article
        # longer than the model's maximum input length would fail in the forward
        # pass. Truncation drops tokens from the end, which keeps the headline
        # at the front of the joined input.
        return classifier(text, truncation=True)[0]["label"]

    def __call__(self, headline: str, content: str) -> Dict[str, str]:
        """Classify an article.

        Args:
            headline: The article's headline.
            content: The article's body text.

        Returns:
            A dict with keys "category", "fake" and "clickbait". Each value is
            the top label produced by the corresponding model.
        """
        category_text = self._join_article(
            self.category_tokenizer.sep_token, headline, content
        )
        fake_text = self._join_article(self.fake_tokenizer.sep_token, headline, content)
        return {
            "category": self._top_label(self.category_pipeline, category_text),
            "fake": self._top_label(self.fake_pipeline, fake_text),
            "clickbait": self._top_label(self.clickbait_pipeline, headline),
        }
|