Ritvik19 committed on
Commit
38f970a
·
1 Parent(s): f1fda5b
Files changed (3) hide show
  1. README.md +1 -3
  2. pipeline.py +0 -81
  3. requirements.txt +0 -1
README.md CHANGED
@@ -3,9 +3,7 @@ language:
3
  - en
4
  tags:
5
  - sentiment-analysis
6
- - generic
7
- - text-classification
8
- library_name: generic
9
  ---
10
 
11
  ## Overview
 
3
  - en
4
  tags:
5
  - sentiment-analysis
6
+ - sklearn
 
 
7
  ---
8
 
9
  ## Overview
pipeline.py DELETED
@@ -1,81 +0,0 @@
1
- import cleantext
2
- import joblib
3
- import os
4
-
5
class PreTrainedPipeline():
    """Score texts with a set of per-class models loaded from disk.

    One model is loaded per entry in ``CLASS_NAMES`` from
    ``<path>/<class_name>.bin``. Each model must expose
    ``predict_proba``; the value in column 1 of the first returned row
    (presumably the positive-class probability -- confirm against how the
    models were trained) is rescaled to a numeric range and rounded to
    two decimals.
    """

    # Names of the models to load; each maps to "<name>.bin" under `path`.
    # The "toxicity" / "emotion" prefixes are what `call` groups on.
    CLASS_NAMES = (
        "sentiment_polarity",
        "opinion",
        "toxicity",
        "toxicity__hate",
        "toxicity__insult",
        "toxicity__obscene",
        "toxicity__sexual_explicit",
        "toxicity__threat",
        "emotion__no_emotion",
        "emotion__anger",
        "emotion__disgust",
        "emotion__fear",
        "emotion__guilt",
        "emotion__humour",
        "emotion__joy",
        "emotion__sadness",
        "emotion__shame",
        "emotion__surprise",
    )

    def __init__(self, path) -> None:
        """Load every model found under the directory ``path``."""
        self.models = self.load_models(path)

    def load_models(self, path) -> dict:
        """Return a dict mapping each class name to its loaded model.

        Raises whatever ``joblib.load`` raises when a ``.bin`` file is
        missing or unreadable.
        """
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only point `path` at model files you trust.
        return {
            class_name: joblib.load(os.path.join(path, f"{class_name}.bin"))
            for class_name in self.CLASS_NAMES
        }

    def clean_text(self, text) -> str:
        """Normalize ``text`` with ``cleantext.clean`` before scoring."""
        return cleantext.clean(
            text,
            fix_unicode=True,  # fix various unicode errors
            to_ascii=True,  # transliterate to closest ASCII representation
            lower=True,  # lowercase text
            no_line_breaks=False,  # fully strip line breaks as opposed to only normalizing them
            no_urls=False,  # replace all URLs with a special token
            no_emails=False,  # replace all email addresses with a special token
            no_phone_numbers=False,  # replace all phone numbers with a special token
            no_numbers=False,  # replace all numbers with a special token
            no_digits=False,  # replace all digits with a special token
            no_currency_symbols=False,  # replace all currency symbols with a special token
            no_punct=False,  # remove punctuations
            replace_with_punct="",  # instead of removing punctuations you may replace them
            replace_with_url="<URL>",
            replace_with_email="<EMAIL>",
            replace_with_phone_number="<PHONE>",
            replace_with_number="<NUMBER>",
            replace_with_digit="0",
            replace_with_currency_symbol="<CUR>",
            lang="en",  # set to 'de' for German special handling
        )

    @staticmethod
    def _score(cleaned_text, model, scale_min=0, scale_max=100) -> float:
        """Rescale the model's column-1 probability for already-cleaned text."""
        probability = model.predict_proba([cleaned_text])[0][1]
        return round(probability * (scale_max - scale_min) + scale_min, 2)

    def get_prediction(self, text, model, scale_min=0, scale_max=100) -> float:
        """Clean ``text`` and return ``model``'s score for it.

        The score is the column-1 probability mapped linearly onto
        ``[scale_min, scale_max]`` and rounded to two decimals.
        """
        return self._score(self.clean_text(text), model, scale_min, scale_max)

    def call(self, text):
        """Score one text with every loaded model.

        Returns a dict with keys ``sentiment_polarity`` (scaled to
        -100..100), ``opinion`` (0..100), and ``toxicity`` / ``emotion``,
        each a dict of 0..100 scores keyed by the model names that start
        with that prefix.
        """
        # Cleaning does not depend on the model, so do it once per text
        # instead of once per model.
        cleaned = self.clean_text(text)

        result = {}
        result["sentiment_polarity"] = self._score(
            cleaned, self.models["sentiment_polarity"], scale_min=-100, scale_max=100
        )
        result["opinion"] = self._score(cleaned, self.models["opinion"])
        result["toxicity"] = {
            class_name: self._score(cleaned, model)
            for class_name, model in self.models.items()
            if class_name.startswith("toxicity")
        }
        result["emotion"] = {
            class_name: self._score(cleaned, model)
            for class_name, model in self.models.items()
            if class_name.startswith("emotion")
        }

        return result

    def __call__(self, texts) -> list:
        """Return one result dict (see ``call``) per text in ``texts``.

        A bare string is treated as a single text; iterating it would
        otherwise score every character separately.
        """
        if isinstance(texts, str):
            texts = [texts]
        return [self.call(text) for text in texts]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt DELETED
@@ -1 +0,0 @@
1
- clean-text