Ubuntu commited on
Commit
ed39e1a
·
1 Parent(s): db53170

finetuned again , converted logits to individual probablities

Browse files
Files changed (28) hide show
  1. finetuned_entity_categorical_classification/checkpoint-1606/added_tokens.json +7 -0
  2. finetuned_entity_categorical_classification/checkpoint-1606/config.json +83 -0
  3. finetuned_entity_categorical_classification/checkpoint-1606/optimizer.pt +3 -0
  4. finetuned_entity_categorical_classification/checkpoint-1606/pytorch_model.bin +3 -0
  5. finetuned_entity_categorical_classification/checkpoint-1606/rng_state.pth +0 -0
  6. finetuned_entity_categorical_classification/checkpoint-1606/scheduler.pt +3 -0
  7. finetuned_entity_categorical_classification/checkpoint-1606/special_tokens_map.json +7 -0
  8. finetuned_entity_categorical_classification/checkpoint-1606/tokenizer.json +0 -0
  9. finetuned_entity_categorical_classification/checkpoint-1606/tokenizer_config.json +56 -0
  10. finetuned_entity_categorical_classification/checkpoint-1606/trainer_state.json +46 -0
  11. finetuned_entity_categorical_classification/checkpoint-1606/training_args.bin +3 -0
  12. finetuned_entity_categorical_classification/checkpoint-1606/vocab.txt +0 -0
  13. finetuned_entity_categorical_classification/checkpoint-3212/added_tokens.json +7 -0
  14. finetuned_entity_categorical_classification/checkpoint-3212/config.json +83 -0
  15. finetuned_entity_categorical_classification/checkpoint-3212/optimizer.pt +3 -0
  16. finetuned_entity_categorical_classification/checkpoint-3212/pytorch_model.bin +3 -0
  17. finetuned_entity_categorical_classification/checkpoint-3212/rng_state.pth +0 -0
  18. finetuned_entity_categorical_classification/checkpoint-3212/scheduler.pt +3 -0
  19. finetuned_entity_categorical_classification/checkpoint-3212/special_tokens_map.json +7 -0
  20. finetuned_entity_categorical_classification/checkpoint-3212/tokenizer.json +0 -0
  21. finetuned_entity_categorical_classification/checkpoint-3212/tokenizer_config.json +56 -0
  22. finetuned_entity_categorical_classification/checkpoint-3212/trainer_state.json +73 -0
  23. finetuned_entity_categorical_classification/checkpoint-3212/training_args.bin +3 -0
  24. finetuned_entity_categorical_classification/checkpoint-3212/vocab.txt +0 -0
  25. finetuned_entity_categorical_classification/runs/Oct12_07-34-46_ip-172-31-95-165/events.out.tfevents.1697096087.ip-172-31-95-165.123522.0 +0 -0
  26. research/09_fine_tuning_for_datacategories.ipynb +187 -187
  27. research/09_inference.html +510 -223
  28. research/09_inference.ipynb +802 -212
finetuned_entity_categorical_classification/checkpoint-1606/added_tokens.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "[CLS]": 101,
3
+ "[MASK]": 103,
4
+ "[PAD]": 0,
5
+ "[SEP]": 102,
6
+ "[UNK]": 100
7
+ }
finetuned_entity_categorical_classification/checkpoint-1606/config.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "finetuned_entity_categorical_classification/checkpoint-23640",
3
+ "activation": "gelu",
4
+ "architectures": [
5
+ "DistilBertForSequenceClassification"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "dim": 768,
9
+ "dropout": 0.1,
10
+ "hidden_dim": 3072,
11
+ "id2label": {
12
+ "0": "Hobbies_and_Leisure",
13
+ "1": "News",
14
+ "2": "Science",
15
+ "3": "Autos_and_Vehicles",
16
+ "4": "Health",
17
+ "5": "Pets_and_Animals",
18
+ "6": "Adult",
19
+ "7": "Computers_and_Electronics",
20
+ "8": "Online Communities",
21
+ "9": "Beauty_and_Fitness",
22
+ "10": "People_and_Society",
23
+ "11": "Business_and_Industrial",
24
+ "12": "Reference",
25
+ "13": "Shopping",
26
+ "14": "Travel_and_Transportation",
27
+ "15": "Food_and_Drink",
28
+ "16": "Law_and_Government",
29
+ "17": "Books_and_Literature",
30
+ "18": "Finance",
31
+ "19": "Games",
32
+ "20": "Home_and_Garden",
33
+ "21": "Jobs_and_Education",
34
+ "22": "Arts_and_Entertainment",
35
+ "23": "Sensitive Subjects",
36
+ "24": "Real Estate",
37
+ "25": "Internet_and_Telecom",
38
+ "26": "Sports"
39
+ },
40
+ "initializer_range": 0.02,
41
+ "label2id": {
42
+ "Adult": 6,
43
+ "Arts_and_Entertainment": 22,
44
+ "Autos_and_Vehicles": 3,
45
+ "Beauty_and_Fitness": 9,
46
+ "Books_and_Literature": 17,
47
+ "Business_and_Industrial": 11,
48
+ "Computers_and_Electronics": 7,
49
+ "Finance": 18,
50
+ "Food_and_Drink": 15,
51
+ "Games": 19,
52
+ "Health": 4,
53
+ "Hobbies_and_Leisure": 0,
54
+ "Home_and_Garden": 20,
55
+ "Internet_and_Telecom": 25,
56
+ "Jobs_and_Education": 21,
57
+ "Law_and_Government": 16,
58
+ "News": 1,
59
+ "Online Communities": 8,
60
+ "People_and_Society": 10,
61
+ "Pets_and_Animals": 5,
62
+ "Real Estate": 24,
63
+ "Reference": 12,
64
+ "Science": 2,
65
+ "Sensitive Subjects": 23,
66
+ "Shopping": 13,
67
+ "Sports": 26,
68
+ "Travel_and_Transportation": 14
69
+ },
70
+ "max_position_embeddings": 512,
71
+ "model_type": "distilbert",
72
+ "n_heads": 12,
73
+ "n_layers": 6,
74
+ "pad_token_id": 0,
75
+ "problem_type": "single_label_classification",
76
+ "qa_dropout": 0.1,
77
+ "seq_classif_dropout": 0.2,
78
+ "sinusoidal_pos_embds": false,
79
+ "tie_weights_": true,
80
+ "torch_dtype": "float32",
81
+ "transformers_version": "4.34.0",
82
+ "vocab_size": 30522
83
+ }
finetuned_entity_categorical_classification/checkpoint-1606/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19acedc3d2479a0b702fa99dc2fbb3d136d6fc0d8c4d7f60c4a7801790fa7f78
3
+ size 535881018
finetuned_entity_categorical_classification/checkpoint-1606/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:412cefb52413fe32419f820fbf788fcb8f36b7ec706fa6533ae06eb5fce7a85d
3
+ size 267932842
finetuned_entity_categorical_classification/checkpoint-1606/rng_state.pth ADDED
Binary file (14.2 kB). View file
 
finetuned_entity_categorical_classification/checkpoint-1606/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8dfaaebd8d17209d079d1f5be496af774586b0a5360dfbd5dfc8c1773baeed3a
3
+ size 1064
finetuned_entity_categorical_classification/checkpoint-1606/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
finetuned_entity_categorical_classification/checkpoint-1606/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
finetuned_entity_categorical_classification/checkpoint-1606/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "additional_special_tokens": [],
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "[CLS]",
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "DistilBertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
finetuned_entity_categorical_classification/checkpoint-1606/trainer_state.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": 0.11884114891290665,
3
+ "best_model_checkpoint": "finetuned_entity_categorical_classification/checkpoint-1606",
4
+ "epoch": 1.0,
5
+ "eval_steps": 500,
6
+ "global_step": 1606,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.31,
13
+ "learning_rate": 1.6886674968866752e-05,
14
+ "loss": 1.0674,
15
+ "step": 500
16
+ },
17
+ {
18
+ "epoch": 0.62,
19
+ "learning_rate": 1.37733499377335e-05,
20
+ "loss": 0.1399,
21
+ "step": 1000
22
+ },
23
+ {
24
+ "epoch": 0.93,
25
+ "learning_rate": 1.066002490660025e-05,
26
+ "loss": 0.1337,
27
+ "step": 1500
28
+ },
29
+ {
30
+ "epoch": 1.0,
31
+ "eval_accuracy": 0.9736842105263158,
32
+ "eval_loss": 0.11884114891290665,
33
+ "eval_runtime": 2.2458,
34
+ "eval_samples_per_second": 2859.611,
35
+ "eval_steps_per_second": 179.004,
36
+ "step": 1606
37
+ }
38
+ ],
39
+ "logging_steps": 500,
40
+ "max_steps": 3212,
41
+ "num_train_epochs": 2,
42
+ "save_steps": 500,
43
+ "total_flos": 101362033000800.0,
44
+ "trial_name": null,
45
+ "trial_params": null
46
+ }
finetuned_entity_categorical_classification/checkpoint-1606/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0adabf2f73371d63a132d200cc272c0595f2b10bd579056ad508da7aa97ef66e
3
+ size 4600
finetuned_entity_categorical_classification/checkpoint-1606/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
finetuned_entity_categorical_classification/checkpoint-3212/added_tokens.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "[CLS]": 101,
3
+ "[MASK]": 103,
4
+ "[PAD]": 0,
5
+ "[SEP]": 102,
6
+ "[UNK]": 100
7
+ }
finetuned_entity_categorical_classification/checkpoint-3212/config.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "finetuned_entity_categorical_classification/checkpoint-23640",
3
+ "activation": "gelu",
4
+ "architectures": [
5
+ "DistilBertForSequenceClassification"
6
+ ],
7
+ "attention_dropout": 0.1,
8
+ "dim": 768,
9
+ "dropout": 0.1,
10
+ "hidden_dim": 3072,
11
+ "id2label": {
12
+ "0": "Hobbies_and_Leisure",
13
+ "1": "News",
14
+ "2": "Science",
15
+ "3": "Autos_and_Vehicles",
16
+ "4": "Health",
17
+ "5": "Pets_and_Animals",
18
+ "6": "Adult",
19
+ "7": "Computers_and_Electronics",
20
+ "8": "Online Communities",
21
+ "9": "Beauty_and_Fitness",
22
+ "10": "People_and_Society",
23
+ "11": "Business_and_Industrial",
24
+ "12": "Reference",
25
+ "13": "Shopping",
26
+ "14": "Travel_and_Transportation",
27
+ "15": "Food_and_Drink",
28
+ "16": "Law_and_Government",
29
+ "17": "Books_and_Literature",
30
+ "18": "Finance",
31
+ "19": "Games",
32
+ "20": "Home_and_Garden",
33
+ "21": "Jobs_and_Education",
34
+ "22": "Arts_and_Entertainment",
35
+ "23": "Sensitive Subjects",
36
+ "24": "Real Estate",
37
+ "25": "Internet_and_Telecom",
38
+ "26": "Sports"
39
+ },
40
+ "initializer_range": 0.02,
41
+ "label2id": {
42
+ "Adult": 6,
43
+ "Arts_and_Entertainment": 22,
44
+ "Autos_and_Vehicles": 3,
45
+ "Beauty_and_Fitness": 9,
46
+ "Books_and_Literature": 17,
47
+ "Business_and_Industrial": 11,
48
+ "Computers_and_Electronics": 7,
49
+ "Finance": 18,
50
+ "Food_and_Drink": 15,
51
+ "Games": 19,
52
+ "Health": 4,
53
+ "Hobbies_and_Leisure": 0,
54
+ "Home_and_Garden": 20,
55
+ "Internet_and_Telecom": 25,
56
+ "Jobs_and_Education": 21,
57
+ "Law_and_Government": 16,
58
+ "News": 1,
59
+ "Online Communities": 8,
60
+ "People_and_Society": 10,
61
+ "Pets_and_Animals": 5,
62
+ "Real Estate": 24,
63
+ "Reference": 12,
64
+ "Science": 2,
65
+ "Sensitive Subjects": 23,
66
+ "Shopping": 13,
67
+ "Sports": 26,
68
+ "Travel_and_Transportation": 14
69
+ },
70
+ "max_position_embeddings": 512,
71
+ "model_type": "distilbert",
72
+ "n_heads": 12,
73
+ "n_layers": 6,
74
+ "pad_token_id": 0,
75
+ "problem_type": "single_label_classification",
76
+ "qa_dropout": 0.1,
77
+ "seq_classif_dropout": 0.2,
78
+ "sinusoidal_pos_embds": false,
79
+ "tie_weights_": true,
80
+ "torch_dtype": "float32",
81
+ "transformers_version": "4.34.0",
82
+ "vocab_size": 30522
83
+ }
finetuned_entity_categorical_classification/checkpoint-3212/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:949e9674bd8f44b7f3b456d7c8866cf9e7f2f56afe03fc520788d71cc9e5d877
3
+ size 535881018
finetuned_entity_categorical_classification/checkpoint-3212/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1137fbeaa6b73c979f5792601cebf66e0fd9ed02b20d85fd1467fc78cd1e26c
3
+ size 267932842
finetuned_entity_categorical_classification/checkpoint-3212/rng_state.pth ADDED
Binary file (14.2 kB). View file
 
finetuned_entity_categorical_classification/checkpoint-3212/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8915d9bc09464457beb8a4c6765791b388089c2c9de68f8b710b52b7951ae1d9
3
+ size 1064
finetuned_entity_categorical_classification/checkpoint-3212/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
finetuned_entity_categorical_classification/checkpoint-3212/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
finetuned_entity_categorical_classification/checkpoint-3212/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "additional_special_tokens": [],
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "[CLS]",
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "DistilBertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
finetuned_entity_categorical_classification/checkpoint-3212/trainer_state.json ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": 0.11884114891290665,
3
+ "best_model_checkpoint": "finetuned_entity_categorical_classification/checkpoint-1606",
4
+ "epoch": 2.0,
5
+ "eval_steps": 500,
6
+ "global_step": 3212,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.31,
13
+ "learning_rate": 1.6886674968866752e-05,
14
+ "loss": 1.0674,
15
+ "step": 500
16
+ },
17
+ {
18
+ "epoch": 0.62,
19
+ "learning_rate": 1.37733499377335e-05,
20
+ "loss": 0.1399,
21
+ "step": 1000
22
+ },
23
+ {
24
+ "epoch": 0.93,
25
+ "learning_rate": 1.066002490660025e-05,
26
+ "loss": 0.1337,
27
+ "step": 1500
28
+ },
29
+ {
30
+ "epoch": 1.0,
31
+ "eval_accuracy": 0.9736842105263158,
32
+ "eval_loss": 0.11884114891290665,
33
+ "eval_runtime": 2.2458,
34
+ "eval_samples_per_second": 2859.611,
35
+ "eval_steps_per_second": 179.004,
36
+ "step": 1606
37
+ },
38
+ {
39
+ "epoch": 1.25,
40
+ "learning_rate": 7.5466998754669995e-06,
41
+ "loss": 0.1071,
42
+ "step": 2000
43
+ },
44
+ {
45
+ "epoch": 1.56,
46
+ "learning_rate": 4.433374844333748e-06,
47
+ "loss": 0.0813,
48
+ "step": 2500
49
+ },
50
+ {
51
+ "epoch": 1.87,
52
+ "learning_rate": 1.3200498132004982e-06,
53
+ "loss": 0.0963,
54
+ "step": 3000
55
+ },
56
+ {
57
+ "epoch": 2.0,
58
+ "eval_accuracy": 0.9732170663344752,
59
+ "eval_loss": 0.12265542149543762,
60
+ "eval_runtime": 2.2396,
61
+ "eval_samples_per_second": 2867.448,
62
+ "eval_steps_per_second": 179.495,
63
+ "step": 3212
64
+ }
65
+ ],
66
+ "logging_steps": 500,
67
+ "max_steps": 3212,
68
+ "num_train_epochs": 2,
69
+ "save_steps": 500,
70
+ "total_flos": 202880405807352.0,
71
+ "trial_name": null,
72
+ "trial_params": null
73
+ }
finetuned_entity_categorical_classification/checkpoint-3212/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0adabf2f73371d63a132d200cc272c0595f2b10bd579056ad508da7aa97ef66e
3
+ size 4600
finetuned_entity_categorical_classification/checkpoint-3212/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
finetuned_entity_categorical_classification/runs/Oct12_07-34-46_ip-172-31-95-165/events.out.tfevents.1697096087.ip-172-31-95-165.123522.0 ADDED
Binary file (7.68 kB). View file
 
research/09_fine_tuning_for_datacategories.ipynb CHANGED
@@ -62,93 +62,93 @@
62
  " </thead>\n",
63
  " <tbody>\n",
64
  " <tr>\n",
65
- " <th>2461</th>\n",
66
- " <td>Business networking tips</td>\n",
67
- " <td>Business_and_Industrial</td>\n",
68
- " <td>12</td>\n",
69
  " </tr>\n",
70
  " <tr>\n",
71
- " <th>10003</th>\n",
72
- " <td>Industrial development and infrastructure proj...</td>\n",
73
- " <td>Business_and_Industrial</td>\n",
74
- " <td>12</td>\n",
75
  " </tr>\n",
76
  " <tr>\n",
77
- " <th>14189</th>\n",
78
- " <td>Music theory and composition discussions</td>\n",
79
- " <td>Online Communities</td>\n",
80
- " <td>20</td>\n",
81
  " </tr>\n",
82
  " <tr>\n",
83
- " <th>19874</th>\n",
84
- " <td>Civil litigation process efficiency impact inf...</td>\n",
85
- " <td>Law_and_Government</td>\n",
86
- " <td>11</td>\n",
87
  " </tr>\n",
88
  " <tr>\n",
89
- " <th>5676</th>\n",
90
- " <td>Human rights violations investigations effecti...</td>\n",
91
- " <td>Law_and_Government</td>\n",
92
  " <td>11</td>\n",
93
  " </tr>\n",
94
  " <tr>\n",
95
- " <th>27151</th>\n",
96
- " <td>Vehicle history apps</td>\n",
97
- " <td>Autos_and_Vehicles</td>\n",
98
- " <td>10</td>\n",
99
  " </tr>\n",
100
  " <tr>\n",
101
- " <th>21837</th>\n",
102
- " <td>Online references</td>\n",
103
- " <td>Reference</td>\n",
104
  " <td>25</td>\n",
105
  " </tr>\n",
106
  " <tr>\n",
107
- " <th>5541</th>\n",
108
- " <td>Gay Movies Gay</td>\n",
109
- " <td>Adult</td>\n",
110
- " <td>4</td>\n",
111
  " </tr>\n",
112
  " <tr>\n",
113
- " <th>10734</th>\n",
114
- " <td>Catfood for senior cats</td>\n",
115
- " <td>Food_and_Drink</td>\n",
116
- " <td>7</td>\n",
117
  " </tr>\n",
118
  " <tr>\n",
119
- " <th>25164</th>\n",
120
- " <td>Internet safety for sports fans</td>\n",
121
- " <td>Internet_and_Telecom</td>\n",
122
- " <td>17</td>\n",
123
  " </tr>\n",
124
  " </tbody>\n",
125
  "</table>\n",
126
  "</div>"
127
  ],
128
  "text/plain": [
129
- " category \\\n",
130
- "2461 Business networking tips \n",
131
- "10003 Industrial development and infrastructure proj... \n",
132
- "14189 Music theory and composition discussions \n",
133
- "19874 Civil litigation process efficiency impact inf... \n",
134
- "5676 Human rights violations investigations effecti... \n",
135
- "27151 Vehicle history apps \n",
136
- "21837 Online references \n",
137
- "5541 Gay Movies Gay \n",
138
- "10734 Catfood for senior cats \n",
139
- "25164 Internet safety for sports fans \n",
140
  "\n",
141
- " label label_id \n",
142
- "2461 Business_and_Industrial 12 \n",
143
- "10003 Business_and_Industrial 12 \n",
144
- "14189 Online Communities 20 \n",
145
- "19874 Law_and_Government 11 \n",
146
- "5676 Law_and_Government 11 \n",
147
- "27151 Autos_and_Vehicles 10 \n",
148
- "21837 Reference 25 \n",
149
- "5541 Adult 4 \n",
150
- "10734 Food_and_Drink 7 \n",
151
- "25164 Internet_and_Telecom 17 "
152
  ]
153
  },
154
  "execution_count": 3,
@@ -196,40 +196,40 @@
196
  " <tbody>\n",
197
  " <tr>\n",
198
  " <th>0</th>\n",
199
- " <td>Pet nutrition consulting for small mammal spec...</td>\n",
200
- " <td>26</td>\n",
201
  " </tr>\n",
202
  " <tr>\n",
203
  " <th>1</th>\n",
204
- " <td>Makeup for mature skin</td>\n",
205
- " <td>24</td>\n",
206
  " </tr>\n",
207
  " <tr>\n",
208
  " <th>2</th>\n",
209
- " <td>Volunteer opportunities near me</td>\n",
210
- " <td>1</td>\n",
211
  " </tr>\n",
212
  " <tr>\n",
213
  " <th>3</th>\n",
214
- " <td>Financial responsibility for college graduates</td>\n",
215
- " <td>21</td>\n",
216
  " </tr>\n",
217
  " <tr>\n",
218
  " <th>4</th>\n",
219
- " <td>Distance learning</td>\n",
220
- " <td>19</td>\n",
221
  " </tr>\n",
222
  " </tbody>\n",
223
  "</table>\n",
224
  "</div>"
225
  ],
226
  "text/plain": [
227
- " category label_id\n",
228
- "0 Pet nutrition consulting for small mammal spec... 26\n",
229
- "1 Makeup for mature skin 24\n",
230
- "2 Volunteer opportunities near me 1\n",
231
- "3 Financial responsibility for college graduates 21\n",
232
- "4 Distance learning 19"
233
  ]
234
  },
235
  "execution_count": 4,
@@ -250,8 +250,8 @@
250
  {
251
  "data": {
252
  "text/plain": [
253
- "False 20792\n",
254
- "True 11038\n",
255
  "Name: count, dtype: int64"
256
  ]
257
  },
@@ -273,7 +273,7 @@
273
  "name": "stderr",
274
  "output_type": "stream",
275
  "text": [
276
- "/tmp/ipykernel_122572/984288843.py:1: SettingWithCopyWarning: \n",
277
  "A value is trying to be set on a copy of a slice from a DataFrame\n",
278
  "\n",
279
  "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
@@ -307,54 +307,54 @@
307
  " </thead>\n",
308
  " <tbody>\n",
309
  " <tr>\n",
310
- " <th>18501</th>\n",
311
- " <td>Vegan diet myths</td>\n",
312
- " <td>24</td>\n",
313
  " </tr>\n",
314
  " <tr>\n",
315
- " <th>5596</th>\n",
316
- " <td>Game industry news</td>\n",
317
- " <td>8</td>\n",
318
  " </tr>\n",
319
  " <tr>\n",
320
- " <th>31812</th>\n",
321
- " <td>Sports Team Fan Support</td>\n",
322
- " <td>5</td>\n",
323
  " </tr>\n",
324
  " <tr>\n",
325
- " <th>31249</th>\n",
326
- " <td>free Granny</td>\n",
327
- " <td>4</td>\n",
328
  " </tr>\n",
329
  " <tr>\n",
330
- " <th>19536</th>\n",
331
- " <td>Travel destination monastic retreats</td>\n",
332
- " <td>2</td>\n",
333
  " </tr>\n",
334
  " <tr>\n",
335
- " <th>29460</th>\n",
336
- " <td>Sports statistics</td>\n",
337
- " <td>8</td>\n",
338
  " </tr>\n",
339
  " <tr>\n",
340
- " <th>12554</th>\n",
341
- " <td>Online payment systems</td>\n",
342
- " <td>16</td>\n",
343
  " </tr>\n",
344
  " <tr>\n",
345
- " <th>26502</th>\n",
346
- " <td>eSports Game Esports Player Fan Engagement Ini...</td>\n",
347
- " <td>5</td>\n",
348
  " </tr>\n",
349
  " <tr>\n",
350
- " <th>24910</th>\n",
351
- " <td>Financial empowerment strategies for empowerment</td>\n",
352
- " <td>21</td>\n",
353
  " </tr>\n",
354
  " <tr>\n",
355
- " <th>20072</th>\n",
356
- " <td>Kickboxing gloves</td>\n",
357
- " <td>0</td>\n",
358
  " </tr>\n",
359
  " </tbody>\n",
360
  "</table>\n",
@@ -362,16 +362,16 @@
362
  ],
363
  "text/plain": [
364
  " text label\n",
365
- "18501 Vegan diet myths 24\n",
366
- "5596 Game industry news 8\n",
367
- "31812 Sports Team Fan Support 5\n",
368
- "31249 free Granny 4\n",
369
- "19536 Travel destination monastic retreats 2\n",
370
- "29460 Sports statistics 8\n",
371
- "12554 Online payment systems 16\n",
372
- "26502 eSports Game Esports Player Fan Engagement Ini... 5\n",
373
- "24910 Financial empowerment strategies for empowerment 21\n",
374
- "20072 Kickboxing gloves 0"
375
  ]
376
  },
377
  "execution_count": 6,
@@ -409,7 +409,7 @@
409
  "text/plain": [
410
  "Dataset({\n",
411
  " features: ['text', 'label'],\n",
412
- " num_rows: 31830\n",
413
  "})"
414
  ]
415
  },
@@ -434,11 +434,11 @@
434
  "DatasetDict({\n",
435
  " train: Dataset({\n",
436
  " features: ['text', 'label'],\n",
437
- " num_rows: 25464\n",
438
  " })\n",
439
  " test: Dataset({\n",
440
  " features: ['text', 'label'],\n",
441
- " num_rows: 6366\n",
442
  " })\n",
443
  "})"
444
  ]
@@ -483,8 +483,8 @@
483
  "name": "stderr",
484
  "output_type": "stream",
485
  "text": [
486
- "Map: 100%|██████████| 25464/25464 [00:00<00:00, 33634.70 examples/s]\n",
487
- "Map: 100%|██████████| 6366/6366 [00:00<00:00, 37230.41 examples/s]\n"
488
  ]
489
  }
490
  ],
@@ -501,9 +501,9 @@
501
  "name": "stderr",
502
  "output_type": "stream",
503
  "text": [
504
- "2023-10-12 07:17:13.000135: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
505
  "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
506
- "2023-10-12 07:17:14.376613: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
507
  ]
508
  }
509
  ],
@@ -563,33 +563,33 @@
563
  {
564
  "data": {
565
  "text/plain": [
566
- "{'Beauty_and_Fitness': 0,\n",
567
- " 'People_and_Society': 1,\n",
568
- " 'Travel_and_Transportation': 2,\n",
569
- " 'Shopping': 3,\n",
570
- " 'Adult': 4,\n",
571
- " 'Sports': 5,\n",
572
- " 'Science': 6,\n",
573
- " 'Food_and_Drink': 7,\n",
574
- " 'News': 8,\n",
575
- " 'Sensitive Subjects': 9,\n",
576
- " 'Autos_and_Vehicles': 10,\n",
577
- " 'Law_and_Government': 11,\n",
578
- " 'Business_and_Industrial': 12,\n",
579
- " 'Health': 13,\n",
580
- " 'Real Estate': 14,\n",
581
- " 'Books_and_Literature': 15,\n",
582
- " 'Computers_and_Electronics': 16,\n",
583
- " 'Internet_and_Telecom': 17,\n",
584
- " 'Home_and_Garden': 18,\n",
585
- " 'Jobs_and_Education': 19,\n",
586
- " 'Online Communities': 20,\n",
587
- " 'Finance': 21,\n",
588
  " 'Arts_and_Entertainment': 22,\n",
589
- " 'Games': 23,\n",
590
- " 'Hobbies_and_Leisure': 24,\n",
591
- " 'Reference': 25,\n",
592
- " 'Pets_and_Animals': 26}"
593
  ]
594
  },
595
  "execution_count": 16,
@@ -612,33 +612,33 @@
612
  {
613
  "data": {
614
  "text/plain": [
615
- "{0: 'Beauty_and_Fitness',\n",
616
- " 1: 'People_and_Society',\n",
617
- " 2: 'Travel_and_Transportation',\n",
618
- " 3: 'Shopping',\n",
619
- " 4: 'Adult',\n",
620
- " 5: 'Sports',\n",
621
- " 6: 'Science',\n",
622
- " 7: 'Food_and_Drink',\n",
623
- " 8: 'News',\n",
624
- " 9: 'Sensitive Subjects',\n",
625
- " 10: 'Autos_and_Vehicles',\n",
626
- " 11: 'Law_and_Government',\n",
627
- " 12: 'Business_and_Industrial',\n",
628
- " 13: 'Health',\n",
629
- " 14: 'Real Estate',\n",
630
- " 15: 'Books_and_Literature',\n",
631
- " 16: 'Computers_and_Electronics',\n",
632
- " 17: 'Internet_and_Telecom',\n",
633
- " 18: 'Home_and_Garden',\n",
634
- " 19: 'Jobs_and_Education',\n",
635
- " 20: 'Online Communities',\n",
636
- " 21: 'Finance',\n",
637
  " 22: 'Arts_and_Entertainment',\n",
638
- " 23: 'Games',\n",
639
- " 24: 'Hobbies_and_Leisure',\n",
640
- " 25: 'Reference',\n",
641
- " 26: 'Pets_and_Animals'}"
642
  ]
643
  },
644
  "execution_count": 17,
@@ -685,8 +685,8 @@
685
  "\n",
686
  " <div>\n",
687
  " \n",
688
- " <progress value='3184' max='3184' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
689
- " [3184/3184 01:41, Epoch 2/2]\n",
690
  " </div>\n",
691
  " <table border=\"1\" class=\"dataframe\">\n",
692
  " <thead>\n",
@@ -700,15 +700,15 @@
700
  " <tbody>\n",
701
  " <tr>\n",
702
  " <td>1</td>\n",
703
- " <td>0.098900</td>\n",
704
- " <td>0.105882</td>\n",
705
- " <td>0.972196</td>\n",
706
  " </tr>\n",
707
  " <tr>\n",
708
  " <td>2</td>\n",
709
- " <td>0.060600</td>\n",
710
- " <td>0.105043</td>\n",
711
- " <td>0.973296</td>\n",
712
  " </tr>\n",
713
  " </tbody>\n",
714
  "</table><p>"
@@ -723,7 +723,7 @@
723
  {
724
  "data": {
725
  "text/plain": [
726
- "TrainOutput(global_step=3184, training_loss=0.08544207278208517, metrics={'train_runtime': 102.0004, 'train_samples_per_second': 499.292, 'train_steps_per_second': 31.216, 'total_flos': 204535951167600.0, 'train_loss': 0.08544207278208517, 'epoch': 2.0})"
727
  ]
728
  },
729
  "execution_count": 19,
 
62
  " </thead>\n",
63
  " <tbody>\n",
64
  " <tr>\n",
65
+ " <th>30126</th>\n",
66
+ " <td>Farmers market products</td>\n",
67
+ " <td>Food_and_Drink</td>\n",
68
+ " <td>15</td>\n",
69
  " </tr>\n",
70
  " <tr>\n",
71
+ " <th>14239</th>\n",
72
+ " <td>Political rallies</td>\n",
73
+ " <td>News</td>\n",
74
+ " <td>1</td>\n",
75
  " </tr>\n",
76
  " <tr>\n",
77
+ " <th>20410</th>\n",
78
+ " <td>Diversity celebrations</td>\n",
79
+ " <td>People_and_Society</td>\n",
80
+ " <td>10</td>\n",
81
  " </tr>\n",
82
  " <tr>\n",
83
+ " <th>1446</th>\n",
84
+ " <td>Remote work and remote project management</td>\n",
85
+ " <td>Jobs_and_Education</td>\n",
86
+ " <td>21</td>\n",
87
  " </tr>\n",
88
  " <tr>\n",
89
+ " <th>6985</th>\n",
90
+ " <td>Industrial equipment suppliers</td>\n",
91
+ " <td>Business_and_Industrial</td>\n",
92
  " <td>11</td>\n",
93
  " </tr>\n",
94
  " <tr>\n",
95
+ " <th>30906</th>\n",
96
+ " <td>Guided sleep meditation</td>\n",
97
+ " <td>Beauty_and_Fitness</td>\n",
98
+ " <td>9</td>\n",
99
  " </tr>\n",
100
  " <tr>\n",
101
+ " <th>4351</th>\n",
102
+ " <td>VPN for business</td>\n",
103
+ " <td>Internet_and_Telecom</td>\n",
104
  " <td>25</td>\n",
105
  " </tr>\n",
106
  " <tr>\n",
107
+ " <th>8599</th>\n",
108
+ " <td>Razer Kraken ear cushions</td>\n",
109
+ " <td>Computers_and_Electronics</td>\n",
110
+ " <td>7</td>\n",
111
  " </tr>\n",
112
  " <tr>\n",
113
+ " <th>28322</th>\n",
114
+ " <td>Citation context organization platforms</td>\n",
115
+ " <td>Reference</td>\n",
116
+ " <td>12</td>\n",
117
  " </tr>\n",
118
  " <tr>\n",
119
+ " <th>5368</th>\n",
120
+ " <td>Quality Porn Videos</td>\n",
121
+ " <td>Adult</td>\n",
122
+ " <td>6</td>\n",
123
  " </tr>\n",
124
  " </tbody>\n",
125
  "</table>\n",
126
  "</div>"
127
  ],
128
  "text/plain": [
129
+ " category label \\\n",
130
+ "30126 Farmers market products Food_and_Drink \n",
131
+ "14239 Political rallies News \n",
132
+ "20410 Diversity celebrations People_and_Society \n",
133
+ "1446 Remote work and remote project management Jobs_and_Education \n",
134
+ "6985 Industrial equipment suppliers Business_and_Industrial \n",
135
+ "30906 Guided sleep meditation Beauty_and_Fitness \n",
136
+ "4351 VPN for business Internet_and_Telecom \n",
137
+ "8599 Razer Kraken ear cushions Computers_and_Electronics \n",
138
+ "28322 Citation context organization platforms Reference \n",
139
+ "5368 Quality Porn Videos Adult \n",
140
  "\n",
141
+ " label_id \n",
142
+ "30126 15 \n",
143
+ "14239 1 \n",
144
+ "20410 10 \n",
145
+ "1446 21 \n",
146
+ "6985 11 \n",
147
+ "30906 9 \n",
148
+ "4351 25 \n",
149
+ "8599 7 \n",
150
+ "28322 12 \n",
151
+ "5368 6 "
152
  ]
153
  },
154
  "execution_count": 3,
 
196
  " <tbody>\n",
197
  " <tr>\n",
198
  " <th>0</th>\n",
199
+ " <td>DIY woodworking projects</td>\n",
200
+ " <td>20</td>\n",
201
  " </tr>\n",
202
  " <tr>\n",
203
  " <th>1</th>\n",
204
+ " <td>Music festivals lineup leaks</td>\n",
205
+ " <td>22</td>\n",
206
  " </tr>\n",
207
  " <tr>\n",
208
  " <th>2</th>\n",
209
+ " <td>Sports Team Fan Love</td>\n",
210
+ " <td>26</td>\n",
211
  " </tr>\n",
212
  " <tr>\n",
213
  " <th>3</th>\n",
214
+ " <td>Food portion control and portion control apps</td>\n",
215
+ " <td>15</td>\n",
216
  " </tr>\n",
217
  " <tr>\n",
218
  " <th>4</th>\n",
219
+ " <td>Planting flower beds</td>\n",
220
+ " <td>20</td>\n",
221
  " </tr>\n",
222
  " </tbody>\n",
223
  "</table>\n",
224
  "</div>"
225
  ],
226
  "text/plain": [
227
+ " category label_id\n",
228
+ "0 DIY woodworking projects 20\n",
229
+ "1 Music festivals lineup leaks 22\n",
230
+ "2 Sports Team Fan Love 26\n",
231
+ "3 Food portion control and portion control apps 15\n",
232
+ "4 Planting flower beds 20"
233
  ]
234
  },
235
  "execution_count": 4,
 
250
  {
251
  "data": {
252
  "text/plain": [
253
+ "False 21064\n",
254
+ "True 11044\n",
255
  "Name: count, dtype: int64"
256
  ]
257
  },
 
273
  "name": "stderr",
274
  "output_type": "stream",
275
  "text": [
276
+ "/tmp/ipykernel_123522/984288843.py:1: SettingWithCopyWarning: \n",
277
  "A value is trying to be set on a copy of a slice from a DataFrame\n",
278
  "\n",
279
  "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
 
307
  " </thead>\n",
308
  " <tbody>\n",
309
  " <tr>\n",
310
+ " <th>22892</th>\n",
311
+ " <td>Business data analysis tools</td>\n",
312
+ " <td>11</td>\n",
313
  " </tr>\n",
314
  " <tr>\n",
315
+ " <th>26952</th>\n",
316
+ " <td>Movie posters minimalist iconic film symbols a...</td>\n",
317
+ " <td>22</td>\n",
318
  " </tr>\n",
319
  " <tr>\n",
320
+ " <th>27699</th>\n",
321
+ " <td>Sports Team Fan Parties</td>\n",
322
+ " <td>26</td>\n",
323
  " </tr>\n",
324
  " <tr>\n",
325
+ " <th>6288</th>\n",
326
+ " <td>Collectible vintage items and antiques</td>\n",
327
+ " <td>13</td>\n",
328
  " </tr>\n",
329
  " <tr>\n",
330
+ " <th>22173</th>\n",
331
+ " <td>Skin rejuvenation treatments and procedures</td>\n",
332
+ " <td>9</td>\n",
333
  " </tr>\n",
334
  " <tr>\n",
335
+ " <th>13124</th>\n",
336
+ " <td>Poetry analysis guidelines</td>\n",
337
+ " <td>22</td>\n",
338
  " </tr>\n",
339
  " <tr>\n",
340
+ " <th>20269</th>\n",
341
+ " <td>Health Education for Men</td>\n",
342
+ " <td>4</td>\n",
343
  " </tr>\n",
344
  " <tr>\n",
345
+ " <th>10112</th>\n",
346
+ " <td>MacBook Pro Ports</td>\n",
347
+ " <td>7</td>\n",
348
  " </tr>\n",
349
  " <tr>\n",
350
+ " <th>31312</th>\n",
351
+ " <td>Mixology equipment for home bartenders and mix...</td>\n",
352
+ " <td>15</td>\n",
353
  " </tr>\n",
354
  " <tr>\n",
355
+ " <th>30209</th>\n",
356
+ " <td>Poetry analysis examples with explanations</td>\n",
357
+ " <td>22</td>\n",
358
  " </tr>\n",
359
  " </tbody>\n",
360
  "</table>\n",
 
362
  ],
363
  "text/plain": [
364
  " text label\n",
365
+ "22892 Business data analysis tools 11\n",
366
+ "26952 Movie posters minimalist iconic film symbols a... 22\n",
367
+ "27699 Sports Team Fan Parties 26\n",
368
+ "6288 Collectible vintage items and antiques 13\n",
369
+ "22173 Skin rejuvenation treatments and procedures 9\n",
370
+ "13124 Poetry analysis guidelines 22\n",
371
+ "20269 Health Education for Men 4\n",
372
+ "10112 MacBook Pro Ports 7\n",
373
+ "31312 Mixology equipment for home bartenders and mix... 15\n",
374
+ "30209 Poetry analysis examples with explanations 22"
375
  ]
376
  },
377
  "execution_count": 6,
 
409
  "text/plain": [
410
  "Dataset({\n",
411
  " features: ['text', 'label'],\n",
412
+ " num_rows: 32108\n",
413
  "})"
414
  ]
415
  },
 
434
  "DatasetDict({\n",
435
  " train: Dataset({\n",
436
  " features: ['text', 'label'],\n",
437
+ " num_rows: 25686\n",
438
  " })\n",
439
  " test: Dataset({\n",
440
  " features: ['text', 'label'],\n",
441
+ " num_rows: 6422\n",
442
  " })\n",
443
  "})"
444
  ]
 
483
  "name": "stderr",
484
  "output_type": "stream",
485
  "text": [
486
+ "Map: 100%|██████████| 25686/25686 [00:00<00:00, 33313.88 examples/s]\n",
487
+ "Map: 100%|██████████| 6422/6422 [00:00<00:00, 41958.07 examples/s]\n"
488
  ]
489
  }
490
  ],
 
501
  "name": "stderr",
502
  "output_type": "stream",
503
  "text": [
504
+ "2023-10-12 07:34:40.359726: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
505
  "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
506
+ "2023-10-12 07:34:41.887700: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
507
  ]
508
  }
509
  ],
 
563
  {
564
  "data": {
565
  "text/plain": [
566
+ "{'Hobbies_and_Leisure': 0,\n",
567
+ " 'News': 1,\n",
568
+ " 'Science': 2,\n",
569
+ " 'Autos_and_Vehicles': 3,\n",
570
+ " 'Health': 4,\n",
571
+ " 'Pets_and_Animals': 5,\n",
572
+ " 'Adult': 6,\n",
573
+ " 'Computers_and_Electronics': 7,\n",
574
+ " 'Online Communities': 8,\n",
575
+ " 'Beauty_and_Fitness': 9,\n",
576
+ " 'People_and_Society': 10,\n",
577
+ " 'Business_and_Industrial': 11,\n",
578
+ " 'Reference': 12,\n",
579
+ " 'Shopping': 13,\n",
580
+ " 'Travel_and_Transportation': 14,\n",
581
+ " 'Food_and_Drink': 15,\n",
582
+ " 'Law_and_Government': 16,\n",
583
+ " 'Books_and_Literature': 17,\n",
584
+ " 'Finance': 18,\n",
585
+ " 'Games': 19,\n",
586
+ " 'Home_and_Garden': 20,\n",
587
+ " 'Jobs_and_Education': 21,\n",
588
  " 'Arts_and_Entertainment': 22,\n",
589
+ " 'Sensitive Subjects': 23,\n",
590
+ " 'Real Estate': 24,\n",
591
+ " 'Internet_and_Telecom': 25,\n",
592
+ " 'Sports': 26}"
593
  ]
594
  },
595
  "execution_count": 16,
 
612
  {
613
  "data": {
614
  "text/plain": [
615
+ "{0: 'Hobbies_and_Leisure',\n",
616
+ " 1: 'News',\n",
617
+ " 2: 'Science',\n",
618
+ " 3: 'Autos_and_Vehicles',\n",
619
+ " 4: 'Health',\n",
620
+ " 5: 'Pets_and_Animals',\n",
621
+ " 6: 'Adult',\n",
622
+ " 7: 'Computers_and_Electronics',\n",
623
+ " 8: 'Online Communities',\n",
624
+ " 9: 'Beauty_and_Fitness',\n",
625
+ " 10: 'People_and_Society',\n",
626
+ " 11: 'Business_and_Industrial',\n",
627
+ " 12: 'Reference',\n",
628
+ " 13: 'Shopping',\n",
629
+ " 14: 'Travel_and_Transportation',\n",
630
+ " 15: 'Food_and_Drink',\n",
631
+ " 16: 'Law_and_Government',\n",
632
+ " 17: 'Books_and_Literature',\n",
633
+ " 18: 'Finance',\n",
634
+ " 19: 'Games',\n",
635
+ " 20: 'Home_and_Garden',\n",
636
+ " 21: 'Jobs_and_Education',\n",
637
  " 22: 'Arts_and_Entertainment',\n",
638
+ " 23: 'Sensitive Subjects',\n",
639
+ " 24: 'Real Estate',\n",
640
+ " 25: 'Internet_and_Telecom',\n",
641
+ " 26: 'Sports'}"
642
  ]
643
  },
644
  "execution_count": 17,
 
685
  "\n",
686
  " <div>\n",
687
  " \n",
688
+ " <progress value='3212' max='3212' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
689
+ " [3212/3212 01:44, Epoch 2/2]\n",
690
  " </div>\n",
691
  " <table border=\"1\" class=\"dataframe\">\n",
692
  " <thead>\n",
 
700
  " <tbody>\n",
701
  " <tr>\n",
702
  " <td>1</td>\n",
703
+ " <td>0.133700</td>\n",
704
+ " <td>0.118841</td>\n",
705
+ " <td>0.973684</td>\n",
706
  " </tr>\n",
707
  " <tr>\n",
708
  " <td>2</td>\n",
709
+ " <td>0.096300</td>\n",
710
+ " <td>0.122655</td>\n",
711
+ " <td>0.973217</td>\n",
712
  " </tr>\n",
713
  " </tbody>\n",
714
  "</table><p>"
 
723
  {
724
  "data": {
725
  "text/plain": [
726
+ "TrainOutput(global_step=3212, training_loss=0.2577073345445607, metrics={'train_runtime': 105.4831, 'train_samples_per_second': 487.016, 'train_steps_per_second': 30.45, 'total_flos': 202880405807352.0, 'train_loss': 0.2577073345445607, 'epoch': 2.0})"
727
  ]
728
  },
729
  "execution_count": 19,
research/09_inference.html CHANGED
@@ -7475,7 +7475,7 @@ a.anchor-link {
7475
  </style>
7476
  <!-- End of mermaid configuration --></head>
7477
  <body class="jp-Notebook" data-jp-theme-light="true" data-jp-theme-name="JupyterLab Light">
7478
- <main><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
7479
  <div class="jp-Cell-inputWrapper" tabindex="0">
7480
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7481
  </div>
@@ -7492,25 +7492,7 @@ a.anchor-link {
7492
  </div>
7493
  </div>
7494
  </div>
7495
- <div class="jp-Cell-outputWrapper">
7496
- <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
7497
- </div>
7498
- <div class="jp-OutputArea jp-Cell-outputArea">
7499
- <div class="jp-OutputArea-child">
7500
- <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7501
- <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="application/vnd.jupyter.stderr" tabindex="0">
7502
- <pre>/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
7503
- from .autonotebook import tqdm as notebook_tqdm
7504
- 2023-10-12 05:59:27.575495: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
7505
- To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
7506
- 2023-10-12 05:59:28.314367: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
7507
- Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
7508
- </pre>
7509
- </div>
7510
- </div>
7511
- </div>
7512
- </div>
7513
- </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
7514
  <div class="jp-Cell-inputWrapper" tabindex="0">
7515
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7516
  </div>
@@ -7526,19 +7508,7 @@ Special tokens have been added in the vocabulary, make sure the associated word
7526
  </div>
7527
  </div>
7528
  </div>
7529
- <div class="jp-Cell-outputWrapper">
7530
- <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
7531
- </div>
7532
- <div class="jp-OutputArea jp-Cell-outputArea">
7533
- <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7534
- <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7535
- <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7536
- <pre>[{'label': 'Computers_and_Electronics', 'score': 0.9999090433120728}]</pre>
7537
- </div>
7538
- </div>
7539
- </div>
7540
- </div>
7541
- </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
7542
  <div class="jp-Cell-inputWrapper" tabindex="0">
7543
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7544
  </div>
@@ -7554,19 +7524,7 @@ Special tokens have been added in the vocabulary, make sure the associated word
7554
  </div>
7555
  </div>
7556
  </div>
7557
- <div class="jp-Cell-outputWrapper">
7558
- <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
7559
- </div>
7560
- <div class="jp-OutputArea jp-Cell-outputArea">
7561
- <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7562
- <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7563
- <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7564
- <pre>[{'label': 'Health', 'score': 0.49160146713256836}]</pre>
7565
- </div>
7566
- </div>
7567
- </div>
7568
- </div>
7569
- </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
7570
  <div class="jp-Cell-inputWrapper" tabindex="0">
7571
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7572
  </div>
@@ -7582,18 +7540,6 @@ Special tokens have been added in the vocabulary, make sure the associated word
7582
  </div>
7583
  </div>
7584
  </div>
7585
- <div class="jp-Cell-outputWrapper">
7586
- <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
7587
- </div>
7588
- <div class="jp-OutputArea jp-Cell-outputArea">
7589
- <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7590
- <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7591
- <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7592
- <pre>[{'label': 'Computers_and_Electronics', 'score': 0.9995001554489136}]</pre>
7593
- </div>
7594
- </div>
7595
- </div>
7596
- </div>
7597
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
7598
  <div class="jp-Cell-inputWrapper" tabindex="0">
7599
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
@@ -7619,7 +7565,7 @@ Special tokens have been added in the vocabulary, make sure the associated word
7619
  </div>
7620
  </div>
7621
  </div>
7622
- </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
7623
  <div class="jp-Cell-inputWrapper" tabindex="0">
7624
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7625
  </div>
@@ -7628,11 +7574,24 @@ Special tokens have been added in the vocabulary, make sure the associated word
7628
  <div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
7629
  <div class="cm-editor cm-s-jupyter">
7630
  <div class="highlight hl-ipython3"><pre><span></span><span class="kn">import</span> <span class="nn">os</span><span class="p">;</span> <span class="n">os</span><span class="o">.</span><span class="n">chdir</span><span class="p">(</span><span class="s1">'..'</span><span class="p">)</span>
 
7631
  </pre></div>
7632
  </div>
7633
  </div>
7634
  </div>
7635
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
7636
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
7637
  <div class="jp-Cell-inputWrapper" tabindex="0">
7638
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
@@ -7668,7 +7627,7 @@ Special tokens have been added in the vocabulary, make sure the associated word
7668
  <span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
7669
 
7670
 
7671
- <span class="n">model_name</span><span class="o">=</span> <span class="s2">"finetuned_entity_categorical_classification/checkpoint-23355"</span>
7672
  <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
7673
 
7674
  <span class="n">model</span> <span class="o">=</span> <span class="n">AutoModelForSequenceClassification</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
@@ -7708,14 +7667,20 @@ Special tokens have been added in the vocabulary, make sure the associated word
7708
  <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
7709
  <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">)</span><span class="o">.</span><span class="n">logits</span>
7710
 
7711
- <span class="nb">print</span><span class="p">(</span><span class="s2">"logits: "</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span>
7712
  <span class="n">predicted_class_id</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">argmax</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
7713
  <span class="c1"># get probabilities using softmax from logit score and convert it to numpy array</span>
7714
  <span class="n">probabilities_scores</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
 
7715
  <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">27</span><span class="p">):</span>
7716
- <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"P(</span><span class="si">{</span><span class="n">id2label</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2">): </span><span class="si">{</span><span class="n">probabilities_scores</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
 
 
 
 
7717
 
7718
- <span class="nb">print</span><span class="p">(</span><span class="s2">"Predicted Class: "</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">id2label</span><span class="p">[</span><span class="n">predicted_class_id</span><span class="p">])</span>
 
7719
 
7720
 
7721
 
@@ -7745,41 +7710,44 @@ Special tokens have been added in the vocabulary, make sure the associated word
7745
  <div class="jp-OutputArea-child">
7746
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7747
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7748
- <pre>logits: tensor([[-1.4210, -4.3130, -4.1497, -1.9217, -3.3253, -2.3839, -3.2943, -4.6091,
7749
- -1.9258, -3.6359, -3.7877, -4.2664, -5.1229, -1.8067, -4.4068, -4.5855,
7750
- 10.0302, 0.0293, -2.0481, -5.8791, -3.7072, -3.1037, -4.1602, -0.8520,
7751
- -3.6628, -4.5927, -4.0272]])
7752
- P(Beauty_and_Fitness): 1.0635108992573805e-05
7753
- P(People_and_Society): 5.899146344745532e-07
7754
- P(Travel_and_Transportation): 6.945512041056645e-07
7755
- P(Shopping): 6.446343832067214e-06
7756
- P(Adult): 1.583859898346418e-06
7757
- P(Sports): 4.060307219333481e-06
7758
- P(Science): 1.6337769466190366e-06
7759
- P(Food_and_Drink): 4.3873527033611026e-07
7760
- P(News): 6.419656983780442e-06
7761
- P(Sensitive Subjects): 1.1609599823714234e-06
7762
- P(Autos_and_Vehicles): 9.975190096156439e-07
7763
- P(Law_and_Government): 6.180094374030887e-07
7764
- P(Business_and_Industrial): 2.6243591833008395e-07
7765
- P(Health): 7.231980362121249e-06
7766
- P(Real Estate): 5.370690701056446e-07
7767
- P(Books_and_Literature): 4.492034122449695e-07
7768
- P(Computers_and_Electronics): 0.9998801946640015
7769
- P(Internet_and_Telecom): 4.535169500741176e-05
7770
- P(Home_and_Garden): 5.680800768459449e-06
7771
- P(Jobs_and_Education): 1.2321044096097467e-07
7772
- P(Online Communities): 1.081151822290849e-06
7773
- P(Finance): 1.976913608814357e-06
7774
- P(Arts_and_Entertainment): 6.872939479762863e-07
7775
- P(Games): 1.8787852241075598e-05
7776
- P(Hobbies_and_Leisure): 1.1302184930173098e-06
7777
- P(Reference): 4.4596322368306573e-07
7778
- P(Pets_and_Animals): 7.850715633139771e-07
7779
- Predicted Class: Computers_and_Electronics
7780
  </pre>
7781
  </div>
7782
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7783
  </div>
7784
  </div>
7785
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
@@ -7803,41 +7771,44 @@ Predicted Class: Computers_and_Electronics
7803
  <div class="jp-OutputArea-child">
7804
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7805
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7806
- <pre>logits: tensor([[-0.0132, -3.1295, -4.6255, -4.0011, -6.0542, -1.8381, -2.5773, -0.6340,
7807
- -2.6285, -4.2835, -4.8787, -4.8718, -5.0328, 3.7794, -4.6336, -5.0319,
7808
- 1.0592, -4.5573, -3.8045, -5.2772, -5.0804, -3.2632, -4.1335, -3.5920,
7809
- -2.1358, -7.6210, 3.6940]])
7810
- P(Beauty_and_Fitness): 0.011079358868300915
7811
- P(People_and_Society): 0.0004910691641271114
7812
- P(Travel_and_Transportation): 0.00011000979429809377
7813
- P(Shopping): 0.00020539172692224383
7814
- P(Adult): 2.635990676935762e-05
7815
- P(Sports): 0.0017864161636680365
7816
- P(Science): 0.0008529641781933606
7817
- P(Food_and_Drink): 0.005955575965344906
7818
- P(News): 0.000810392084531486
7819
- P(Sensitive Subjects): 0.00015485959011130035
7820
- P(Autos_and_Vehicles): 8.5399005911313e-05
7821
- P(Law_and_Government): 8.598815475124866e-05
7822
- P(Business_and_Industrial): 7.320548320421949e-05
7823
- P(Health): 0.4916036128997803
7824
- P(Real Estate): 0.0001091243393602781
7825
- P(Books_and_Literature): 7.327288767555729e-05
7826
- P(Computers_and_Electronics): 0.03238002583384514
7827
- P(Internet_and_Telecom): 0.00011777772306231782
7828
- P(Home_and_Garden): 0.0002500169211998582
7829
- P(Jobs_and_Education): 5.733156285714358e-05
7830
- P(Online Communities): 6.979802856221795e-05
7831
- P(Finance): 0.00042960469727404416
7832
- P(Arts_and_Entertainment): 0.000179934679181315
7833
- P(Games): 0.00030923119629733264
7834
- P(Hobbies_and_Leisure): 0.0013263950822874904
7835
- P(Reference): 5.501774921867764e-06
7836
- P(Pets_and_Animals): 0.4513714015483856
7837
- Predicted Class: Health
7838
  </pre>
7839
  </div>
7840
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7841
  </div>
7842
  </div>
7843
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
@@ -7861,41 +7832,44 @@ Predicted Class: Health
7861
  <div class="jp-OutputArea-child">
7862
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7863
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7864
- <pre>logits: tensor([[ -0.9597, -3.8456, -3.1203, -1.9988, -5.4966, -3.5321, -2.2676,
7865
- 9.3689, 0.3687, -4.5561, -5.4510, -3.8708, -4.3223, 0.2038,
7866
- -3.1802, -3.6065, -3.8467, -4.6997, -3.8446, -4.4849, -4.4130,
7867
- -2.8653, -2.8191, -4.9874, -1.7339, -10.3458, -1.0289]])
7868
- P(Beauty_and_Fitness): 3.267208012402989e-05
7869
- P(People_and_Society): 1.8231645526611828e-06
7870
- P(Travel_and_Transportation): 3.7654806419595843e-06
7871
- P(Shopping): 1.1558237929421011e-05
7872
- P(Adult): 3.4979228757947567e-07
7873
- P(Sports): 2.4945670702436473e-06
7874
- P(Science): 8.83362372405827e-06
7875
- P(Food_and_Drink): 0.9996380805969238
7876
- P(News): 0.00012333830818533897
7877
- P(Sensitive Subjects): 8.959448223322397e-07
7878
- P(Autos_and_Vehicles): 3.6612007647818245e-07
7879
- P(Law_and_Government): 1.7778713754523778e-06
7880
- P(Business_and_Industrial): 1.13186013095401e-06
7881
- P(Health): 0.0001045860699377954
7882
- P(Real Estate): 3.5467155612423085e-06
7883
- P(Books_and_Literature): 2.3157517716754228e-06
7884
- P(Computers_and_Electronics): 1.821160935833177e-06
7885
- P(Internet_and_Telecom): 7.761184406263055e-07
7886
- P(Home_and_Garden): 1.8250555058330065e-06
7887
- P(Jobs_and_Education): 9.62060425990785e-07
7888
- P(Online Communities): 1.033720309351338e-06
7889
- P(Finance): 4.85956570628332e-06
7890
- P(Arts_and_Entertainment): 5.089193109597545e-06
7891
- P(Games): 5.820724595650972e-07
7892
- P(Hobbies_and_Leisure): 1.5064177205204032e-05
7893
- P(Reference): 2.740457638594762e-09
7894
- P(Pets_and_Animals): 3.0487293770420365e-05
7895
- Predicted Class: Food_and_Drink
7896
  </pre>
7897
  </div>
7898
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7899
  </div>
7900
  </div>
7901
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
@@ -7919,41 +7893,44 @@ Predicted Class: Food_and_Drink
7919
  <div class="jp-OutputArea-child">
7920
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7921
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7922
- <pre>logits: tensor([[-3.9272, -2.2786, -3.7970, -3.0280, -3.9465, -1.5384, -1.8026, -2.8501,
7923
- -2.0297, -4.1079, -3.4422, -3.7435, -3.6991, -2.9137, -2.7802, -4.9385,
7924
- -1.6897, -3.3684, -2.7991, -4.0702, -4.0103, -2.6430, -3.3914, -4.5762,
7925
- -2.0696, -7.2857, 10.6981]])
7926
- P(Beauty_and_Fitness): 4.4490917616712977e-07
7927
- P(People_and_Society): 2.3135853552957997e-06
7928
- P(Travel_and_Transportation): 5.068139330433041e-07
7929
- P(Shopping): 1.0934879810520215e-06
7930
- P(Adult): 4.364295307368593e-07
7931
- P(Sports): 4.849702690989943e-06
7932
- P(Science): 3.723783038367401e-06
7933
- P(Food_and_Drink): 1.306354533880949e-06
7934
- P(News): 2.9673019525944255e-06
7935
- P(Sensitive Subjects): 3.7138897823751904e-07
7936
- P(Autos_and_Vehicles): 7.226597631415643e-07
7937
- P(Law_and_Government): 5.346369107428472e-07
7938
- P(Business_and_Industrial): 5.58940200789948e-07
7939
- P(Health): 1.2258614106031018e-06
7940
- P(Real Estate): 1.4009098094902583e-06
7941
- P(Books_and_Literature): 1.6184709750177717e-07
7942
- P(Computers_and_Electronics): 4.168971827311907e-06
7943
- P(Internet_and_Telecom): 7.780183182148903e-07
7944
- P(Home_and_Garden): 1.3746708873441094e-06
7945
- P(Jobs_and_Education): 3.856556816117518e-07
7946
- P(Online Communities): 4.094476082627807e-07
7947
- P(Finance): 1.6070013089120039e-06
7948
- P(Arts_and_Entertainment): 7.603246672260866e-07
7949
- P(Games): 2.3251060099482856e-07
7950
- P(Hobbies_and_Leisure): 2.8512215521914186e-06
7951
- P(Reference): 1.5477359838200755e-08
7952
- P(Pets_and_Animals): 0.9999648332595825
7953
- Predicted Class: Pets_and_Animals
7954
  </pre>
7955
  </div>
7956
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7957
  </div>
7958
  </div>
7959
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
@@ -7977,41 +7954,351 @@ Predicted Class: Pets_and_Animals
7977
  <div class="jp-OutputArea-child">
7978
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7979
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7980
- <pre>logits: tensor([[-1.5957, -4.4011, -4.4274, -1.5319, -3.9572, -2.2991, -3.6216, -5.3450,
7981
- -2.7176, -3.9352, -4.0612, -4.6522, -5.7079, -1.6673, -4.3583, -4.3791,
7982
- 9.6847, -0.6290, -2.0402, -5.4800, -4.3648, -2.9588, -4.5169, -1.1335,
7983
- -3.7419, -4.2007, -4.0720]])
7984
- P(Beauty_and_Fitness): 1.2615569175977726e-05
7985
- P(People_and_Society): 7.630294476257404e-07
7986
- P(Travel_and_Transportation): 7.431989388351212e-07
7987
- P(Shopping): 1.3446799130178988e-05
7988
- P(Adult): 1.189393287859275e-06
7989
- P(Sports): 6.243569714570185e-06
7990
- P(Science): 1.6636606687825406e-06
7991
- P(Food_and_Drink): 2.969020158616331e-07
7992
- P(News): 4.108236680622213e-06
7993
- P(Sensitive Subjects): 1.2158897106928634e-06
7994
- P(Autos_and_Vehicles): 1.0719173815232352e-06
7995
- P(Law_and_Government): 5.935727926953405e-07
7996
- P(Business_and_Industrial): 2.0653641286116908e-07
7997
- P(Health): 1.1743918548745569e-05
7998
- P(Real Estate): 7.964112000991008e-07
7999
- P(Books_and_Literature): 7.800075536579243e-07
8000
- P(Computers_and_Electronics): 0.9998729228973389
8001
- P(Internet_and_Telecom): 3.316840957268141e-05
8002
- P(Home_and_Garden): 8.08808999863686e-06
8003
- P(Jobs_and_Education): 2.5939584702427965e-07
8004
- P(Online Communities): 7.912186674730037e-07
8005
- P(Finance): 3.2278749131364748e-06
8006
- P(Arts_and_Entertainment): 6.795915510338091e-07
8007
- P(Games): 2.00279555429006e-05
8008
- P(Hobbies_and_Leisure): 1.4751183243788546e-06
8009
- P(Reference): 9.323083531853626e-07
8010
- P(Pets_and_Animals): 1.060413751474698e-06
8011
- Predicted Class: Computers_and_Electronics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8012
  </pre>
8013
  </div>
8014
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8015
  </div>
8016
  </div>
8017
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
 
7475
  </style>
7476
  <!-- End of mermaid configuration --></head>
7477
  <body class="jp-Notebook" data-jp-theme-light="true" data-jp-theme-name="JupyterLab Light">
7478
+ <main><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
7479
  <div class="jp-Cell-inputWrapper" tabindex="0">
7480
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7481
  </div>
 
7492
  </div>
7493
  </div>
7494
  </div>
7495
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7496
  <div class="jp-Cell-inputWrapper" tabindex="0">
7497
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7498
  </div>
 
7508
  </div>
7509
  </div>
7510
  </div>
7511
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
 
 
 
 
 
 
 
 
 
 
 
 
7512
  <div class="jp-Cell-inputWrapper" tabindex="0">
7513
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7514
  </div>
 
7524
  </div>
7525
  </div>
7526
  </div>
7527
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
 
 
 
 
 
 
 
 
 
 
 
 
7528
  <div class="jp-Cell-inputWrapper" tabindex="0">
7529
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7530
  </div>
 
7540
  </div>
7541
  </div>
7542
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
7543
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
7544
  <div class="jp-Cell-inputWrapper" tabindex="0">
7545
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
 
7565
  </div>
7566
  </div>
7567
  </div>
7568
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
7569
  <div class="jp-Cell-inputWrapper" tabindex="0">
7570
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
7571
  </div>
 
7574
  <div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
7575
  <div class="cm-editor cm-s-jupyter">
7576
  <div class="highlight hl-ipython3"><pre><span></span><span class="kn">import</span> <span class="nn">os</span><span class="p">;</span> <span class="n">os</span><span class="o">.</span><span class="n">chdir</span><span class="p">(</span><span class="s1">'..'</span><span class="p">)</span>
7577
+ <span class="o">%</span><span class="k">pwd</span>
7578
  </pre></div>
7579
  </div>
7580
  </div>
7581
  </div>
7582
  </div>
7583
+ <div class="jp-Cell-outputWrapper">
7584
+ <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
7585
+ </div>
7586
+ <div class="jp-OutputArea jp-Cell-outputArea">
7587
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7588
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7589
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7590
+ <pre>'/home/ubuntu/SentenceStructureComparision'</pre>
7591
+ </div>
7592
+ </div>
7593
+ </div>
7594
+ </div>
7595
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
7596
  <div class="jp-Cell-inputWrapper" tabindex="0">
7597
  <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
 
7627
  <span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
7628
 
7629
 
7630
+ <span class="n">model_name</span><span class="o">=</span> <span class="s2">"finetuned_entity_categorical_classification/checkpoint-3212"</span>
7631
  <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
7632
 
7633
  <span class="n">model</span> <span class="o">=</span> <span class="n">AutoModelForSequenceClassification</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_name</span><span class="p">)</span>
 
7667
  <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
7668
  <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">)</span><span class="o">.</span><span class="n">logits</span>
7669
 
7670
+ <span class="c1"># print("logits: ", logits)</span>
7671
  <span class="n">predicted_class_id</span> <span class="o">=</span> <span class="n">logits</span><span class="o">.</span><span class="n">argmax</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
7672
  <span class="c1"># get probabilities using softmax from logit score and convert it to numpy array</span>
7673
  <span class="n">probabilities_scores</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">dim</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span>
7674
+ <span class="n">d</span><span class="o">=</span> <span class="p">{}</span>
7675
  <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">27</span><span class="p">):</span>
7676
+ <span class="c1"># print(f"P({id2label[i]}): {probabilities_scores[i]}")</span>
7677
+ <span class="c1"># d[f'P({id2label[i]})']= format(probabilities_scores[i], '.2f')</span>
7678
+ <span class="n">d</span><span class="p">[</span><span class="sa">f</span><span class="s1">'P(</span><span class="si">{</span><span class="n">id2label</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s1">)'</span><span class="p">]</span><span class="o">=</span> <span class="nb">round</span><span class="p">(</span><span class="n">probabilities_scores</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span>
7679
+
7680
+
7681
 
7682
+ <span class="nb">print</span><span class="p">(</span><span class="s2">"Predicted Class: "</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">id2label</span><span class="p">[</span><span class="n">predicted_class_id</span><span class="p">],</span> <span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">probabilities_scores: </span><span class="si">{</span><span class="n">probabilities_scores</span><span class="p">[</span><span class="n">predicted_class_id</span><span class="p">]</span><span class="si">}</span><span class="se">\n</span><span class="s2">"</span><span class="p">)</span>
7683
+ <span class="k">return</span> <span class="n">d</span>
7684
 
7685
 
7686
 
 
7710
  <div class="jp-OutputArea-child">
7711
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7712
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7713
+ <pre>Predicted Class: Computers_and_Electronics
7714
+ probabilities_scores: 0.9997648596763611
7715
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7716
  </pre>
7717
  </div>
7718
  </div>
7719
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7720
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7721
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7722
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
7723
+ 'P(News)': 0.0,
7724
+ 'P(Science)': 0.0,
7725
+ 'P(Autos_and_Vehicles)': 0.0,
7726
+ 'P(Health)': 0.0,
7727
+ 'P(Pets_and_Animals)': 0.0,
7728
+ 'P(Adult)': 0.0,
7729
+ 'P(Computers_and_Electronics)': 1.0,
7730
+ 'P(Online Communities)': 0.0,
7731
+ 'P(Beauty_and_Fitness)': 0.0,
7732
+ 'P(People_and_Society)': 0.0,
7733
+ 'P(Business_and_Industrial)': 0.0,
7734
+ 'P(Reference)': 0.0,
7735
+ 'P(Shopping)': 0.0,
7736
+ 'P(Travel_and_Transportation)': 0.0,
7737
+ 'P(Food_and_Drink)': 0.0,
7738
+ 'P(Law_and_Government)': 0.0,
7739
+ 'P(Books_and_Literature)': 0.0,
7740
+ 'P(Finance)': 0.0,
7741
+ 'P(Games)': 0.0,
7742
+ 'P(Home_and_Garden)': 0.0,
7743
+ 'P(Jobs_and_Education)': 0.0,
7744
+ 'P(Arts_and_Entertainment)': 0.0,
7745
+ 'P(Sensitive Subjects)': 0.0,
7746
+ 'P(Real Estate)': 0.0,
7747
+ 'P(Internet_and_Telecom)': 0.0,
7748
+ 'P(Sports)': 0.0}</pre>
7749
+ </div>
7750
+ </div>
7751
  </div>
7752
  </div>
7753
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
 
7771
  <div class="jp-OutputArea-child">
7772
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7773
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7774
+ <pre>Predicted Class: Food_and_Drink
7775
+ probabilities_scores: 0.9993139505386353
7776
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7777
  </pre>
7778
  </div>
7779
  </div>
7780
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7781
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7782
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7783
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
7784
+ 'P(News)': 0.0,
7785
+ 'P(Science)': 0.0,
7786
+ 'P(Autos_and_Vehicles)': 0.0,
7787
+ 'P(Health)': 0.0,
7788
+ 'P(Pets_and_Animals)': 0.0,
7789
+ 'P(Adult)': 0.0,
7790
+ 'P(Computers_and_Electronics)': 0.0,
7791
+ 'P(Online Communities)': 0.0,
7792
+ 'P(Beauty_and_Fitness)': 0.0,
7793
+ 'P(People_and_Society)': 0.0,
7794
+ 'P(Business_and_Industrial)': 0.0,
7795
+ 'P(Reference)': 0.0,
7796
+ 'P(Shopping)': 0.0,
7797
+ 'P(Travel_and_Transportation)': 0.0,
7798
+ 'P(Food_and_Drink)': 0.999,
7799
+ 'P(Law_and_Government)': 0.0,
7800
+ 'P(Books_and_Literature)': 0.0,
7801
+ 'P(Finance)': 0.0,
7802
+ 'P(Games)': 0.0,
7803
+ 'P(Home_and_Garden)': 0.0,
7804
+ 'P(Jobs_and_Education)': 0.0,
7805
+ 'P(Arts_and_Entertainment)': 0.0,
7806
+ 'P(Sensitive Subjects)': 0.0,
7807
+ 'P(Real Estate)': 0.0,
7808
+ 'P(Internet_and_Telecom)': 0.0,
7809
+ 'P(Sports)': 0.0}</pre>
7810
+ </div>
7811
+ </div>
7812
  </div>
7813
  </div>
7814
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
 
7832
  <div class="jp-OutputArea-child">
7833
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7834
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7835
+ <pre>Predicted Class: Food_and_Drink
7836
+ probabilities_scores: 0.9997541308403015
7837
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7838
  </pre>
7839
  </div>
7840
  </div>
7841
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7842
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7843
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7844
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
7845
+ 'P(News)': 0.0,
7846
+ 'P(Science)': 0.0,
7847
+ 'P(Autos_and_Vehicles)': 0.0,
7848
+ 'P(Health)': 0.0,
7849
+ 'P(Pets_and_Animals)': 0.0,
7850
+ 'P(Adult)': 0.0,
7851
+ 'P(Computers_and_Electronics)': 0.0,
7852
+ 'P(Online Communities)': 0.0,
7853
+ 'P(Beauty_and_Fitness)': 0.0,
7854
+ 'P(People_and_Society)': 0.0,
7855
+ 'P(Business_and_Industrial)': 0.0,
7856
+ 'P(Reference)': 0.0,
7857
+ 'P(Shopping)': 0.0,
7858
+ 'P(Travel_and_Transportation)': 0.0,
7859
+ 'P(Food_and_Drink)': 1.0,
7860
+ 'P(Law_and_Government)': 0.0,
7861
+ 'P(Books_and_Literature)': 0.0,
7862
+ 'P(Finance)': 0.0,
7863
+ 'P(Games)': 0.0,
7864
+ 'P(Home_and_Garden)': 0.0,
7865
+ 'P(Jobs_and_Education)': 0.0,
7866
+ 'P(Arts_and_Entertainment)': 0.0,
7867
+ 'P(Sensitive Subjects)': 0.0,
7868
+ 'P(Real Estate)': 0.0,
7869
+ 'P(Internet_and_Telecom)': 0.0,
7870
+ 'P(Sports)': 0.0}</pre>
7871
+ </div>
7872
+ </div>
7873
  </div>
7874
  </div>
7875
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
 
7893
  <div class="jp-OutputArea-child">
7894
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7895
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7896
+ <pre>Predicted Class: Food_and_Drink
7897
+ probabilities_scores: 0.9963496923446655
7898
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7899
  </pre>
7900
  </div>
7901
  </div>
7902
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7903
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7904
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7905
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
7906
+ 'P(News)': 0.0,
7907
+ 'P(Science)': 0.0,
7908
+ 'P(Autos_and_Vehicles)': 0.0,
7909
+ 'P(Health)': 0.0,
7910
+ 'P(Pets_and_Animals)': 0.002,
7911
+ 'P(Adult)': 0.0,
7912
+ 'P(Computers_and_Electronics)': 0.0,
7913
+ 'P(Online Communities)': 0.0,
7914
+ 'P(Beauty_and_Fitness)': 0.0,
7915
+ 'P(People_and_Society)': 0.0,
7916
+ 'P(Business_and_Industrial)': 0.0,
7917
+ 'P(Reference)': 0.0,
7918
+ 'P(Shopping)': 0.0,
7919
+ 'P(Travel_and_Transportation)': 0.0,
7920
+ 'P(Food_and_Drink)': 0.996,
7921
+ 'P(Law_and_Government)': 0.0,
7922
+ 'P(Books_and_Literature)': 0.0,
7923
+ 'P(Finance)': 0.0,
7924
+ 'P(Games)': 0.0,
7925
+ 'P(Home_and_Garden)': 0.0,
7926
+ 'P(Jobs_and_Education)': 0.0,
7927
+ 'P(Arts_and_Entertainment)': 0.0,
7928
+ 'P(Sensitive Subjects)': 0.0,
7929
+ 'P(Real Estate)': 0.0,
7930
+ 'P(Internet_and_Telecom)': 0.0,
7931
+ 'P(Sports)': 0.0}</pre>
7932
+ </div>
7933
+ </div>
7934
  </div>
7935
  </div>
7936
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
 
7954
  <div class="jp-OutputArea-child">
7955
  <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
7956
  <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
7957
+ <pre>Predicted Class: Computers_and_Electronics
7958
+ probabilities_scores: 0.999832034111023
7959
+
7960
+ </pre>
7961
+ </div>
7962
+ </div>
7963
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
7964
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
7965
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
7966
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
7967
+ 'P(News)': 0.0,
7968
+ 'P(Science)': 0.0,
7969
+ 'P(Autos_and_Vehicles)': 0.0,
7970
+ 'P(Health)': 0.0,
7971
+ 'P(Pets_and_Animals)': 0.0,
7972
+ 'P(Adult)': 0.0,
7973
+ 'P(Computers_and_Electronics)': 1.0,
7974
+ 'P(Online Communities)': 0.0,
7975
+ 'P(Beauty_and_Fitness)': 0.0,
7976
+ 'P(People_and_Society)': 0.0,
7977
+ 'P(Business_and_Industrial)': 0.0,
7978
+ 'P(Reference)': 0.0,
7979
+ 'P(Shopping)': 0.0,
7980
+ 'P(Travel_and_Transportation)': 0.0,
7981
+ 'P(Food_and_Drink)': 0.0,
7982
+ 'P(Law_and_Government)': 0.0,
7983
+ 'P(Books_and_Literature)': 0.0,
7984
+ 'P(Finance)': 0.0,
7985
+ 'P(Games)': 0.0,
7986
+ 'P(Home_and_Garden)': 0.0,
7987
+ 'P(Jobs_and_Education)': 0.0,
7988
+ 'P(Arts_and_Entertainment)': 0.0,
7989
+ 'P(Sensitive Subjects)': 0.0,
7990
+ 'P(Real Estate)': 0.0,
7991
+ 'P(Internet_and_Telecom)': 0.0,
7992
+ 'P(Sports)': 0.0}</pre>
7993
+ </div>
7994
+ </div>
7995
+ </div>
7996
+ </div>
7997
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
7998
+ <div class="jp-Cell-inputWrapper" tabindex="0">
7999
+ <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
8000
+ </div>
8001
+ <div class="jp-InputArea jp-Cell-inputArea">
8002
+ <div class="jp-InputPrompt jp-InputArea-prompt">In [ ]:</div>
8003
+ <div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
8004
+ <div class="cm-editor cm-s-jupyter">
8005
+ <div class="highlight hl-ipython3"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s2">"apple "</span><span class="p">)</span>
8006
+ </pre></div>
8007
+ </div>
8008
+ </div>
8009
+ </div>
8010
+ </div>
8011
+ <div class="jp-Cell-outputWrapper">
8012
+ <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
8013
+ </div>
8014
+ <div class="jp-OutputArea jp-Cell-outputArea">
8015
+ <div class="jp-OutputArea-child">
8016
+ <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
8017
+ <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
8018
+ <pre>Predicted Class: Food_and_Drink
8019
+ probabilities_scores: 0.5473537445068359
8020
+
8021
  </pre>
8022
  </div>
8023
  </div>
8024
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
8025
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
8026
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
8027
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
8028
+ 'P(News)': 0.0,
8029
+ 'P(Science)': 0.0,
8030
+ 'P(Autos_and_Vehicles)': 0.0,
8031
+ 'P(Health)': 0.0,
8032
+ 'P(Pets_and_Animals)': 0.0,
8033
+ 'P(Adult)': 0.0,
8034
+ 'P(Computers_and_Electronics)': 0.448,
8035
+ 'P(Online Communities)': 0.0,
8036
+ 'P(Beauty_and_Fitness)': 0.0,
8037
+ 'P(People_and_Society)': 0.0,
8038
+ 'P(Business_and_Industrial)': 0.0,
8039
+ 'P(Reference)': 0.0,
8040
+ 'P(Shopping)': 0.001,
8041
+ 'P(Travel_and_Transportation)': 0.0,
8042
+ 'P(Food_and_Drink)': 0.547,
8043
+ 'P(Law_and_Government)': 0.0,
8044
+ 'P(Books_and_Literature)': 0.0,
8045
+ 'P(Finance)': 0.0,
8046
+ 'P(Games)': 0.002,
8047
+ 'P(Home_and_Garden)': 0.0,
8048
+ 'P(Jobs_and_Education)': 0.0,
8049
+ 'P(Arts_and_Entertainment)': 0.0,
8050
+ 'P(Sensitive Subjects)': 0.0,
8051
+ 'P(Real Estate)': 0.0,
8052
+ 'P(Internet_and_Telecom)': 0.0,
8053
+ 'P(Sports)': 0.0}</pre>
8054
+ </div>
8055
+ </div>
8056
+ </div>
8057
+ </div>
8058
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
8059
+ <div class="jp-Cell-inputWrapper" tabindex="0">
8060
+ <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
8061
+ </div>
8062
+ <div class="jp-InputArea jp-Cell-inputArea">
8063
+ <div class="jp-InputPrompt jp-InputArea-prompt">In [ ]:</div>
8064
+ <div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
8065
+ <div class="cm-editor cm-s-jupyter">
8066
+ <div class="highlight hl-ipython3"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s1">'apple iphone'</span><span class="p">)</span>
8067
+ </pre></div>
8068
+ </div>
8069
+ </div>
8070
+ </div>
8071
+ </div>
8072
+ <div class="jp-Cell-outputWrapper">
8073
+ <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
8074
+ </div>
8075
+ <div class="jp-OutputArea jp-Cell-outputArea">
8076
+ <div class="jp-OutputArea-child">
8077
+ <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
8078
+ <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
8079
+ <pre>Predicted Class: Computers_and_Electronics
8080
+ probabilities_scores: 0.9997270703315735
8081
+
8082
+ </pre>
8083
+ </div>
8084
+ </div>
8085
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
8086
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
8087
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
8088
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
8089
+ 'P(News)': 0.0,
8090
+ 'P(Science)': 0.0,
8091
+ 'P(Autos_and_Vehicles)': 0.0,
8092
+ 'P(Health)': 0.0,
8093
+ 'P(Pets_and_Animals)': 0.0,
8094
+ 'P(Adult)': 0.0,
8095
+ 'P(Computers_and_Electronics)': 1.0,
8096
+ 'P(Online Communities)': 0.0,
8097
+ 'P(Beauty_and_Fitness)': 0.0,
8098
+ 'P(People_and_Society)': 0.0,
8099
+ 'P(Business_and_Industrial)': 0.0,
8100
+ 'P(Reference)': 0.0,
8101
+ 'P(Shopping)': 0.0,
8102
+ 'P(Travel_and_Transportation)': 0.0,
8103
+ 'P(Food_and_Drink)': 0.0,
8104
+ 'P(Law_and_Government)': 0.0,
8105
+ 'P(Books_and_Literature)': 0.0,
8106
+ 'P(Finance)': 0.0,
8107
+ 'P(Games)': 0.0,
8108
+ 'P(Home_and_Garden)': 0.0,
8109
+ 'P(Jobs_and_Education)': 0.0,
8110
+ 'P(Arts_and_Entertainment)': 0.0,
8111
+ 'P(Sensitive Subjects)': 0.0,
8112
+ 'P(Real Estate)': 0.0,
8113
+ 'P(Internet_and_Telecom)': 0.0,
8114
+ 'P(Sports)': 0.0}</pre>
8115
+ </div>
8116
+ </div>
8117
+ </div>
8118
+ </div>
8119
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
8120
+ <div class="jp-Cell-inputWrapper" tabindex="0">
8121
+ <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
8122
+ </div>
8123
+ <div class="jp-InputArea jp-Cell-inputArea">
8124
+ <div class="jp-InputPrompt jp-InputArea-prompt">In [ ]:</div>
8125
+ <div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
8126
+ <div class="cm-editor cm-s-jupyter">
8127
+ <div class="highlight hl-ipython3"><pre><span></span><span class="n">predict</span><span class="p">(</span>
8128
+ <span class="s1">'razer kraken'</span>
8129
+ <span class="p">)</span>
8130
+ </pre></div>
8131
+ </div>
8132
+ </div>
8133
+ </div>
8134
+ </div>
8135
+ <div class="jp-Cell-outputWrapper">
8136
+ <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
8137
+ </div>
8138
+ <div class="jp-OutputArea jp-Cell-outputArea">
8139
+ <div class="jp-OutputArea-child">
8140
+ <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
8141
+ <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
8142
+ <pre>Predicted Class: Computers_and_Electronics
8143
+ probabilities_scores: 0.9997072815895081
8144
+
8145
+ </pre>
8146
+ </div>
8147
+ </div>
8148
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
8149
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
8150
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
8151
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
8152
+ 'P(News)': 0.0,
8153
+ 'P(Science)': 0.0,
8154
+ 'P(Autos_and_Vehicles)': 0.0,
8155
+ 'P(Health)': 0.0,
8156
+ 'P(Pets_and_Animals)': 0.0,
8157
+ 'P(Adult)': 0.0,
8158
+ 'P(Computers_and_Electronics)': 1.0,
8159
+ 'P(Online Communities)': 0.0,
8160
+ 'P(Beauty_and_Fitness)': 0.0,
8161
+ 'P(People_and_Society)': 0.0,
8162
+ 'P(Business_and_Industrial)': 0.0,
8163
+ 'P(Reference)': 0.0,
8164
+ 'P(Shopping)': 0.0,
8165
+ 'P(Travel_and_Transportation)': 0.0,
8166
+ 'P(Food_and_Drink)': 0.0,
8167
+ 'P(Law_and_Government)': 0.0,
8168
+ 'P(Books_and_Literature)': 0.0,
8169
+ 'P(Finance)': 0.0,
8170
+ 'P(Games)': 0.0,
8171
+ 'P(Home_and_Garden)': 0.0,
8172
+ 'P(Jobs_and_Education)': 0.0,
8173
+ 'P(Arts_and_Entertainment)': 0.0,
8174
+ 'P(Sensitive Subjects)': 0.0,
8175
+ 'P(Real Estate)': 0.0,
8176
+ 'P(Internet_and_Telecom)': 0.0,
8177
+ 'P(Sports)': 0.0}</pre>
8178
+ </div>
8179
+ </div>
8180
+ </div>
8181
+ </div>
8182
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
8183
+ <div class="jp-Cell-inputWrapper" tabindex="0">
8184
+ <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
8185
+ </div>
8186
+ <div class="jp-InputArea jp-Cell-inputArea">
8187
+ <div class="jp-InputPrompt jp-InputArea-prompt">In [ ]:</div>
8188
+ <div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
8189
+ <div class="cm-editor cm-s-jupyter">
8190
+ <div class="highlight hl-ipython3"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s2">"facebook"</span><span class="p">)</span>
8191
+ </pre></div>
8192
+ </div>
8193
+ </div>
8194
+ </div>
8195
+ </div>
8196
+ <div class="jp-Cell-outputWrapper">
8197
+ <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
8198
+ </div>
8199
+ <div class="jp-OutputArea jp-Cell-outputArea">
8200
+ <div class="jp-OutputArea-child">
8201
+ <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
8202
+ <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
8203
+ <pre>Predicted Class: Online Communities
8204
+ probabilities_scores: 0.997126042842865
8205
+
8206
+ </pre>
8207
+ </div>
8208
+ </div>
8209
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
8210
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
8211
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
8212
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
8213
+ 'P(News)': 0.0,
8214
+ 'P(Science)': 0.0,
8215
+ 'P(Autos_and_Vehicles)': 0.0,
8216
+ 'P(Health)': 0.0,
8217
+ 'P(Pets_and_Animals)': 0.0,
8218
+ 'P(Adult)': 0.0,
8219
+ 'P(Computers_and_Electronics)': 0.001,
8220
+ 'P(Online Communities)': 0.997,
8221
+ 'P(Beauty_and_Fitness)': 0.0,
8222
+ 'P(People_and_Society)': 0.0,
8223
+ 'P(Business_and_Industrial)': 0.0,
8224
+ 'P(Reference)': 0.0,
8225
+ 'P(Shopping)': 0.0,
8226
+ 'P(Travel_and_Transportation)': 0.0,
8227
+ 'P(Food_and_Drink)': 0.0,
8228
+ 'P(Law_and_Government)': 0.0,
8229
+ 'P(Books_and_Literature)': 0.0,
8230
+ 'P(Finance)': 0.0,
8231
+ 'P(Games)': 0.0,
8232
+ 'P(Home_and_Garden)': 0.001,
8233
+ 'P(Jobs_and_Education)': 0.0,
8234
+ 'P(Arts_and_Entertainment)': 0.0,
8235
+ 'P(Sensitive Subjects)': 0.0,
8236
+ 'P(Real Estate)': 0.0,
8237
+ 'P(Internet_and_Telecom)': 0.0,
8238
+ 'P(Sports)': 0.0}</pre>
8239
+ </div>
8240
+ </div>
8241
+ </div>
8242
+ </div>
8243
+ </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell">
8244
+ <div class="jp-Cell-inputWrapper" tabindex="0">
8245
+ <div class="jp-Collapser jp-InputCollapser jp-Cell-inputCollapser">
8246
+ </div>
8247
+ <div class="jp-InputArea jp-Cell-inputArea">
8248
+ <div class="jp-InputPrompt jp-InputArea-prompt">In [ ]:</div>
8249
+ <div class="jp-CodeMirrorEditor jp-Editor jp-InputArea-editor" data-type="inline">
8250
+ <div class="cm-editor cm-s-jupyter">
8251
+ <div class="highlight hl-ipython3"><pre><span></span><span class="n">predict</span><span class="p">(</span><span class="s1">'apple iphone'</span><span class="p">)</span>
8252
+ </pre></div>
8253
+ </div>
8254
+ </div>
8255
+ </div>
8256
+ </div>
8257
+ <div class="jp-Cell-outputWrapper">
8258
+ <div class="jp-Collapser jp-OutputCollapser jp-Cell-outputCollapser">
8259
+ </div>
8260
+ <div class="jp-OutputArea jp-Cell-outputArea">
8261
+ <div class="jp-OutputArea-child">
8262
+ <div class="jp-OutputPrompt jp-OutputArea-prompt"></div>
8263
+ <div class="jp-RenderedText jp-OutputArea-output" data-mime-type="text/plain" tabindex="0">
8264
+ <pre>Predicted Class: Computers_and_Electronics
8265
+ probabilities_scores: 0.9997270703315735
8266
+
8267
+ </pre>
8268
+ </div>
8269
+ </div>
8270
+ <div class="jp-OutputArea-child jp-OutputArea-executeResult">
8271
+ <div class="jp-OutputPrompt jp-OutputArea-prompt">Out[ ]:</div>
8272
+ <div class="jp-RenderedText jp-OutputArea-output jp-OutputArea-executeResult" data-mime-type="text/plain" tabindex="0">
8273
+ <pre>{'P(Hobbies_and_Leisure)': 0.0,
8274
+ 'P(News)': 0.0,
8275
+ 'P(Science)': 0.0,
8276
+ 'P(Autos_and_Vehicles)': 0.0,
8277
+ 'P(Health)': 0.0,
8278
+ 'P(Pets_and_Animals)': 0.0,
8279
+ 'P(Adult)': 0.0,
8280
+ 'P(Computers_and_Electronics)': 1.0,
8281
+ 'P(Online Communities)': 0.0,
8282
+ 'P(Beauty_and_Fitness)': 0.0,
8283
+ 'P(People_and_Society)': 0.0,
8284
+ 'P(Business_and_Industrial)': 0.0,
8285
+ 'P(Reference)': 0.0,
8286
+ 'P(Shopping)': 0.0,
8287
+ 'P(Travel_and_Transportation)': 0.0,
8288
+ 'P(Food_and_Drink)': 0.0,
8289
+ 'P(Law_and_Government)': 0.0,
8290
+ 'P(Books_and_Literature)': 0.0,
8291
+ 'P(Finance)': 0.0,
8292
+ 'P(Games)': 0.0,
8293
+ 'P(Home_and_Garden)': 0.0,
8294
+ 'P(Jobs_and_Education)': 0.0,
8295
+ 'P(Arts_and_Entertainment)': 0.0,
8296
+ 'P(Sensitive Subjects)': 0.0,
8297
+ 'P(Real Estate)': 0.0,
8298
+ 'P(Internet_and_Telecom)': 0.0,
8299
+ 'P(Sports)': 0.0}</pre>
8300
+ </div>
8301
+ </div>
8302
  </div>
8303
  </div>
8304
  </div><div class="jp-Cell jp-CodeCell jp-Notebook-cell jp-mod-noOutputs">
research/09_inference.ipynb CHANGED
@@ -98,9 +98,17 @@
98
  },
99
  {
100
  "cell_type": "code",
101
- "execution_count": 4,
102
  "metadata": {},
103
  "outputs": [
 
 
 
 
 
 
 
 
104
  {
105
  "name": "stderr",
106
  "output_type": "stream",
@@ -114,9 +122,11 @@
114
  "from transformers import AutoModelForSequenceClassification\n",
115
  "import torch\n",
116
  "from torch.nn import functional as F\n",
 
117
  "\n",
118
  "\n",
119
- "model_name= \"finetuned_entity_categorical_classification/checkpoint-3184\"\n",
 
120
  "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
121
  "\n",
122
  "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n"
@@ -124,10 +134,53 @@
124
  },
125
  {
126
  "cell_type": "code",
127
- "execution_count": 5,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  "metadata": {},
129
  "outputs": [],
130
  "source": [
 
 
 
 
 
 
 
 
131
  "\n",
132
  "\n",
133
  "def predict(sentence: str):\n",
@@ -140,16 +193,30 @@
140
  " \n",
141
  " # print(\"logits: \", logits)\n",
142
  " predicted_class_id = logits.argmax().item()\n",
 
143
  " # get probabilities using softmax from logit score and convert it to numpy array\n",
144
  " probabilities_scores = F.softmax(logits, dim = -1).numpy()[0]\n",
 
 
 
145
  " d= {}\n",
 
 
146
  " for i in range(27):\n",
147
  " # print(f\"P({id2label[i]}): {probabilities_scores[i]}\")\n",
148
- " d[f'P({id2label[i]})']= format(probabilities_scores[i], '.2f')\n",
 
 
 
 
 
 
 
 
149
  " \n",
150
  "\n",
151
- " print(\"Predicted Class: \", model.config.id2label[predicted_class_id], f\"probabilities_scores: {probabilities_scores[predicted_class_id]}\")\n",
152
- " return d\n",
153
  " \n",
154
  " \n",
155
  " "
@@ -157,42 +224,53 @@
157
  },
158
  {
159
  "cell_type": "code",
160
- "execution_count": 6,
161
  "metadata": {},
162
  "outputs": [
163
  {
164
  "name": "stdout",
165
  "output_type": "stream",
166
  "text": [
167
- "P(Beauty_and_Fitness): 1.0167686014028732e-05\n",
168
- "P(People_and_Society): 1.406734668307763e-06\n",
169
- "P(Travel_and_Transportation): 9.111173540077289e-07\n",
170
- "P(Shopping): 2.7279720598016866e-05\n",
171
- "P(Adult): 2.7205089736526133e-06\n",
172
- "P(Sports): 2.7785404199676123e-06\n",
173
- "P(Science): 9.693985703052022e-07\n",
174
- "P(Food_and_Drink): 5.907952072448097e-06\n",
175
- "P(News): 8.620731023256667e-06\n",
176
- "P(Sensitive Subjects): 2.1766395548183937e-06\n",
177
- "P(Autos_and_Vehicles): 3.173354627961089e-07\n",
178
- "P(Law_and_Government): 1.089682882593479e-06\n",
179
- "P(Business_and_Industrial): 2.0000404674647143e-06\n",
180
- "P(Health): 8.528571925126016e-06\n",
181
- "P(Real Estate): 6.72997032324929e-07\n",
182
- "P(Books_and_Literature): 1.7418132074453752e-06\n",
183
- "P(Computers_and_Electronics): 0.9998340606689453\n",
184
- "P(Internet_and_Telecom): 4.2605301132425666e-05\n",
185
- "P(Home_and_Garden): 7.0778082772449125e-06\n",
186
- "P(Jobs_and_Education): 3.205217353752232e-07\n",
187
- "P(Online Communities): 7.534316409874009e-06\n",
188
- "P(Finance): 3.597612248995574e-06\n",
189
- "P(Arts_and_Entertainment): 1.5469729532924248e-06\n",
190
- "P(Games): 2.201926508860197e-05\n",
191
- "P(Hobbies_and_Leisure): 2.3530192265752703e-06\n",
192
- "P(Reference): 2.341075600043041e-08\n",
193
- "P(Pets_and_Animals): 1.5077214357006596e-06\n",
194
- "Predicted Class: Computers_and_Electronics probabilities_scores: 0.9998340606689453\n"
195
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  }
197
  ],
198
  "source": [
@@ -201,42 +279,53 @@
201
  },
202
  {
203
  "cell_type": "code",
204
- "execution_count": 7,
205
  "metadata": {},
206
  "outputs": [
207
  {
208
  "name": "stdout",
209
  "output_type": "stream",
210
  "text": [
211
- "P(Beauty_and_Fitness): 0.0002981989237014204\n",
212
- "P(People_and_Society): 1.8243508748128079e-06\n",
213
- "P(Travel_and_Transportation): 1.4317002751340624e-05\n",
214
- "P(Shopping): 9.405774108017795e-06\n",
215
- "P(Adult): 1.2231478194735246e-06\n",
216
- "P(Sports): 6.019924967404222e-06\n",
217
- "P(Science): 7.067929800541606e-06\n",
218
- "P(Food_and_Drink): 0.9972833395004272\n",
219
- "P(News): 0.00014127693430054933\n",
220
- "P(Sensitive Subjects): 2.4317660063388757e-06\n",
221
- "P(Autos_and_Vehicles): 5.870697918908263e-07\n",
222
- "P(Law_and_Government): 3.3484843697806355e-06\n",
223
- "P(Business_and_Industrial): 5.084546046418836e-06\n",
224
- "P(Health): 0.0021307284478098154\n",
225
- "P(Real Estate): 1.483008531977248e-06\n",
226
- "P(Books_and_Literature): 2.4371431663894327e-06\n",
227
- "P(Computers_and_Electronics): 1.0735298928921111e-05\n",
228
- "P(Internet_and_Telecom): 2.851840008588624e-06\n",
229
- "P(Home_and_Garden): 2.7712192149920156e-06\n",
230
- "P(Jobs_and_Education): 1.1146977158205118e-05\n",
231
- "P(Online Communities): 7.0186338234634604e-06\n",
232
- "P(Finance): 5.121751655678963e-06\n",
233
- "P(Arts_and_Entertainment): 8.403771062148735e-06\n",
234
- "P(Games): 2.9928612548246747e-06\n",
235
- "P(Hobbies_and_Leisure): 3.484110129647888e-05\n",
236
- "P(Reference): 6.697590748672155e-08\n",
237
- "P(Pets_and_Animals): 5.252794835541863e-06\n",
238
- "Predicted Class: Food_and_Drink probabilities_scores: 0.9972833395004272\n"
239
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  }
241
  ],
242
  "source": [
@@ -245,42 +334,53 @@
245
  },
246
  {
247
  "cell_type": "code",
248
- "execution_count": 8,
249
  "metadata": {},
250
  "outputs": [
251
  {
252
  "name": "stdout",
253
  "output_type": "stream",
254
  "text": [
255
- "P(Beauty_and_Fitness): 2.6114428692380898e-05\n",
256
- "P(People_and_Society): 6.279856279434171e-07\n",
257
- "P(Travel_and_Transportation): 6.017768100718968e-06\n",
258
- "P(Shopping): 6.115729320299579e-06\n",
259
- "P(Adult): 4.621779794433678e-07\n",
260
- "P(Sports): 8.989664479486237e-07\n",
261
- "P(Science): 4.8601555135974195e-06\n",
262
- "P(Food_and_Drink): 0.9997175335884094\n",
263
- "P(News): 0.00015670375432819128\n",
264
- "P(Sensitive Subjects): 5.142674694980087e-07\n",
265
- "P(Autos_and_Vehicles): 2.1764762436760066e-07\n",
266
- "P(Law_and_Government): 1.2030991456413176e-06\n",
267
- "P(Business_and_Industrial): 1.6263313682429725e-06\n",
268
- "P(Health): 4.478434129850939e-05\n",
269
- "P(Real Estate): 6.337517106658197e-07\n",
270
- "P(Books_and_Literature): 1.2728096407954581e-06\n",
271
- "P(Computers_and_Electronics): 2.8549591206683544e-06\n",
272
- "P(Internet_and_Telecom): 1.3799519820167916e-06\n",
273
- "P(Home_and_Garden): 2.937797489721561e-06\n",
274
- "P(Jobs_and_Education): 4.768957296619192e-06\n",
275
- "P(Online Communities): 2.587612470961176e-06\n",
276
- "P(Finance): 1.5463368754353723e-06\n",
277
- "P(Arts_and_Entertainment): 6.821313945692964e-06\n",
278
- "P(Games): 7.65006177516625e-07\n",
279
- "P(Hobbies_and_Leisure): 4.179368261247873e-06\n",
280
- "P(Reference): 3.270602633165254e-08\n",
281
- "P(Pets_and_Animals): 2.580756472525536e-06\n",
282
- "Predicted Class: Food_and_Drink probabilities_scores: 0.9997175335884094\n"
283
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  }
285
  ],
286
  "source": [
@@ -289,42 +389,53 @@
289
  },
290
  {
291
  "cell_type": "code",
292
- "execution_count": 9,
293
  "metadata": {},
294
  "outputs": [
295
  {
296
  "name": "stdout",
297
  "output_type": "stream",
298
  "text": [
299
- "P(Beauty_and_Fitness): 6.976195891184034e-06\n",
300
- "P(People_and_Society): 1.2303950143177644e-06\n",
301
- "P(Travel_and_Transportation): 1.7862849972516415e-06\n",
302
- "P(Shopping): 5.573031558014918e-06\n",
303
- "P(Adult): 3.2791076591820456e-06\n",
304
- "P(Sports): 5.794179287477164e-06\n",
305
- "P(Science): 8.48299987410428e-06\n",
306
- "P(Food_and_Drink): 0.0005717862513847649\n",
307
- "P(News): 1.0014691724791192e-05\n",
308
- "P(Sensitive Subjects): 2.9312270726222778e-06\n",
309
- "P(Autos_and_Vehicles): 1.5730682889625314e-07\n",
310
- "P(Law_and_Government): 1.0351266155339545e-06\n",
311
- "P(Business_and_Industrial): 1.9998137759102974e-06\n",
312
- "P(Health): 5.863273599970853e-06\n",
313
- "P(Real Estate): 2.589280256870552e-07\n",
314
- "P(Books_and_Literature): 3.1806489459995646e-06\n",
315
- "P(Computers_and_Electronics): 1.6475665688631125e-05\n",
316
- "P(Internet_and_Telecom): 1.3075596143607982e-06\n",
317
- "P(Home_and_Garden): 1.027156031341292e-05\n",
318
- "P(Jobs_and_Education): 1.03862419109646e-06\n",
319
- "P(Online Communities): 4.737964445666876e-06\n",
320
- "P(Finance): 2.0996037619624985e-06\n",
321
- "P(Arts_and_Entertainment): 4.993361471861135e-06\n",
322
- "P(Games): 4.1619005060056224e-06\n",
323
- "P(Hobbies_and_Leisure): 1.088273165805731e-05\n",
324
- "P(Reference): 6.112716022244058e-08\n",
325
- "P(Pets_and_Animals): 0.9993135929107666\n",
326
- "Predicted Class: Pets_and_Animals probabilities_scores: 0.9993135929107666\n"
327
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  }
329
  ],
330
  "source": [
@@ -333,42 +444,53 @@
333
  },
334
  {
335
  "cell_type": "code",
336
- "execution_count": 10,
337
  "metadata": {},
338
  "outputs": [
339
  {
340
  "name": "stdout",
341
  "output_type": "stream",
342
  "text": [
343
- "P(Beauty_and_Fitness): 1.0418082638352644e-05\n",
344
- "P(People_and_Society): 1.198376025968173e-06\n",
345
- "P(Travel_and_Transportation): 5.249040100352431e-07\n",
346
- "P(Shopping): 1.6788271750556305e-05\n",
347
- "P(Adult): 2.3851741843827767e-06\n",
348
- "P(Sports): 1.8478541505828616e-06\n",
349
- "P(Science): 8.450400628134958e-07\n",
350
- "P(Food_and_Drink): 3.6571536838891916e-06\n",
351
- "P(News): 4.5494271034840494e-06\n",
352
- "P(Sensitive Subjects): 2.1925256987742614e-06\n",
353
- "P(Autos_and_Vehicles): 2.598584387669689e-07\n",
354
- "P(Law_and_Government): 9.124052553488582e-07\n",
355
- "P(Business_and_Industrial): 1.343827193522884e-06\n",
356
- "P(Health): 7.631779226358049e-06\n",
357
- "P(Real Estate): 4.913577527076995e-07\n",
358
- "P(Books_and_Literature): 1.6118407302201376e-06\n",
359
- "P(Computers_and_Electronics): 0.9998828172683716\n",
360
- "P(Internet_and_Telecom): 2.9297894798219204e-05\n",
361
- "P(Home_and_Garden): 5.192091521166731e-06\n",
362
- "P(Jobs_and_Education): 2.745251777014346e-07\n",
363
- "P(Online Communities): 6.218880571395857e-06\n",
364
- "P(Finance): 3.290834229119355e-06\n",
365
- "P(Arts_and_Entertainment): 1.541877054478391e-06\n",
366
- "P(Games): 1.1492516023281496e-05\n",
367
- "P(Hobbies_and_Leisure): 1.9986127881566063e-06\n",
368
- "P(Reference): 1.8265923884541735e-08\n",
369
- "P(Pets_and_Animals): 1.1247184374951757e-06\n",
370
- "Predicted Class: Computers_and_Electronics probabilities_scores: 0.9998828172683716\n"
371
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  }
373
  ],
374
  "source": [
@@ -377,42 +499,53 @@
377
  },
378
  {
379
  "cell_type": "code",
380
- "execution_count": 11,
381
  "metadata": {},
382
  "outputs": [
383
  {
384
  "name": "stdout",
385
  "output_type": "stream",
386
  "text": [
387
- "P(Beauty_and_Fitness): 1.086300744645996e-05\n",
388
- "P(People_and_Society): 2.385743300692411e-07\n",
389
- "P(Travel_and_Transportation): 1.9932767827413045e-06\n",
390
- "P(Shopping): 4.334059667598922e-06\n",
391
- "P(Adult): 3.253454110563325e-07\n",
392
- "P(Sports): 8.683252303853806e-07\n",
393
- "P(Science): 2.3967959350557067e-06\n",
394
- "P(Food_and_Drink): 0.9998577833175659\n",
395
- "P(News): 5.469225288834423e-05\n",
396
- "P(Sensitive Subjects): 3.331420828089904e-07\n",
397
- "P(Autos_and_Vehicles): 1.0676290429501023e-07\n",
398
- "P(Law_and_Government): 4.7278643933168496e-07\n",
399
- "P(Business_and_Industrial): 1.5407667888212018e-06\n",
400
- "P(Health): 4.193164568278007e-05\n",
401
- "P(Real Estate): 3.750056123408285e-07\n",
402
- "P(Books_and_Literature): 4.987622901353461e-07\n",
403
- "P(Computers_and_Electronics): 3.906153779098531e-06\n",
404
- "P(Internet_and_Telecom): 8.262347819254501e-07\n",
405
- "P(Home_and_Garden): 1.5766403294037445e-06\n",
406
- "P(Jobs_and_Education): 4.150041149841854e-06\n",
407
- "P(Online Communities): 2.0979061901016394e-06\n",
408
- "P(Finance): 1.1580733598748338e-06\n",
409
- "P(Arts_and_Entertainment): 2.0028785456815967e-06\n",
410
- "P(Games): 9.470307986703119e-07\n",
411
- "P(Hobbies_and_Leisure): 2.5496683520032093e-06\n",
412
- "P(Reference): 1.3998636916312535e-08\n",
413
- "P(Pets_and_Animals): 1.9844153484882554e-06\n",
414
- "Predicted Class: Food_and_Drink probabilities_scores: 0.9998577833175659\n"
415
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  }
417
  ],
418
  "source": [
@@ -421,42 +554,53 @@
421
  },
422
  {
423
  "cell_type": "code",
424
- "execution_count": 12,
425
  "metadata": {},
426
  "outputs": [
427
  {
428
  "name": "stdout",
429
  "output_type": "stream",
430
  "text": [
431
- "P(Beauty_and_Fitness): 0.00013269745977595448\n",
432
- "P(People_and_Society): 4.455394901015097e-06\n",
433
- "P(Travel_and_Transportation): 2.5948824259103276e-05\n",
434
- "P(Shopping): 0.0005248919478617609\n",
435
- "P(Adult): 1.7862246750155464e-05\n",
436
- "P(Sports): 1.6017889720387757e-05\n",
437
- "P(Science): 2.5951496354537085e-05\n",
438
- "P(Food_and_Drink): 0.9478479623794556\n",
439
- "P(News): 0.0002582172746770084\n",
440
- "P(Sensitive Subjects): 1.79517828655662e-05\n",
441
- "P(Autos_and_Vehicles): 4.965268544765422e-06\n",
442
- "P(Law_and_Government): 7.921374162833672e-06\n",
443
- "P(Business_and_Industrial): 0.0001139482410508208\n",
444
- "P(Health): 0.0005791003350168467\n",
445
- "P(Real Estate): 6.392176146619022e-06\n",
446
- "P(Books_and_Literature): 2.4286606276291423e-05\n",
447
- "P(Computers_and_Electronics): 0.049869947135448456\n",
448
- "P(Internet_and_Telecom): 9.170828707283363e-05\n",
449
- "P(Home_and_Garden): 9.513090481050313e-05\n",
450
- "P(Jobs_and_Education): 3.3369826269336045e-05\n",
451
- "P(Online Communities): 8.171715307980776e-05\n",
452
- "P(Finance): 3.625190947786905e-05\n",
453
- "P(Arts_and_Entertainment): 2.533747101551853e-05\n",
454
- "P(Games): 8.59149222378619e-05\n",
455
- "P(Hobbies_and_Leisure): 2.0291698092478327e-05\n",
456
- "P(Reference): 1.9418187946484977e-07\n",
457
- "P(Pets_and_Animals): 5.1680701290024444e-05\n",
458
- "Predicted Class: Food_and_Drink probabilities_scores: 0.9478479623794556\n"
459
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  }
461
  ],
462
  "source": [
@@ -465,22 +609,468 @@
465
  },
466
  {
467
  "cell_type": "code",
468
- "execution_count": null,
469
  "metadata": {},
470
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  "source": [
472
  "predict(\n",
473
  " 'razer kraken'\n",
474
  ")"
475
  ]
476
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  {
478
  "cell_type": "code",
479
  "execution_count": null,
480
  "metadata": {},
481
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
482
  "source": [
483
- "predict(\"facebook\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  ]
485
  },
486
  {
 
98
  },
99
  {
100
  "cell_type": "code",
101
+ "execution_count": 3,
102
  "metadata": {},
103
  "outputs": [
104
+ {
105
+ "name": "stderr",
106
+ "output_type": "stream",
107
+ "text": [
108
+ "/home/ubuntu/SentenceStructureComparision/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
109
+ " from .autonotebook import tqdm as notebook_tqdm\n"
110
+ ]
111
+ },
112
  {
113
  "name": "stderr",
114
  "output_type": "stream",
 
122
  "from transformers import AutoModelForSequenceClassification\n",
123
  "import torch\n",
124
  "from torch.nn import functional as F\n",
125
+ "import numpy as np\n",
126
  "\n",
127
  "\n",
128
+ "\n",
129
+ "model_name= \"finetuned_entity_categorical_classification/checkpoint-3212\"\n",
130
  "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
131
  "\n",
132
  "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n"
 
134
  },
135
  {
136
  "cell_type": "code",
137
+ "execution_count": null,
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": []
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": []
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": []
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": []
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": []
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 50,
173
  "metadata": {},
174
  "outputs": [],
175
  "source": [
176
+ "# probabilities = 1 / (1 + np.exp(-logit_score))\n",
177
+ "def logit2prob(logit):\n",
178
+ " # odds =np.exp(logit)\n",
179
+ " # prob = odds / (1 + odds)\n",
180
+ " prob= 1/(1+ np.exp(-logit))\n",
181
+ " return np.round(prob, 3)\n",
182
+ "\n",
183
+ "\n",
184
  "\n",
185
  "\n",
186
  "def predict(sentence: str):\n",
 
193
  " \n",
194
  " # print(\"logits: \", logits)\n",
195
  " predicted_class_id = logits.argmax().item()\n",
196
+ " \n",
197
  " # get probabilities using softmax from logit score and convert it to numpy array\n",
198
  " probabilities_scores = F.softmax(logits, dim = -1).numpy()[0]\n",
199
+ " individual_probabilities_scores = logit2prob(logits.numpy()[0])\n",
200
+ " \n",
201
+ " \n",
202
  " d= {}\n",
203
+ " d_ind= {}\n",
204
+ " # d_ind= {}\n",
205
  " for i in range(27):\n",
206
  " # print(f\"P({id2label[i]}): {probabilities_scores[i]}\")\n",
207
+ " # d[f'P({id2label[i]})']= format(probabilities_scores[i], '.2f')\n",
208
+ " d[f'P({id2label[i]})']= round(probabilities_scores[i], 3)\n",
209
+ " \n",
210
+ " \n",
211
+ " for i in range(27):\n",
212
+ " # print(f\"P({id2label[i]}): {probabilities_scores[i]}\")\n",
213
+ " # d[f'P({id2label[i]})']= format(probabilities_scores[i], '.2f')\n",
214
+ " d_ind[f'P({id2label[i]})']= (individual_probabilities_scores[i])\n",
215
+ " \n",
216
  " \n",
217
  "\n",
218
+ " print(\"Predicted Class: \", model.config.id2label[predicted_class_id], f\"\\nprobabilities_scores: {individual_probabilities_scores[predicted_class_id]}\\n\")\n",
219
+ " return d_ind\n",
220
  " \n",
221
  " \n",
222
  " "
 
224
  },
225
  {
226
  "cell_type": "code",
227
+ "execution_count": 51,
228
  "metadata": {},
229
  "outputs": [
230
  {
231
  "name": "stdout",
232
  "output_type": "stream",
233
  "text": [
234
+ "Predicted Class: Computers_and_Electronics \n",
235
+ "probabilities_scores: 1.0\n",
236
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  ]
238
+ },
239
+ {
240
+ "data": {
241
+ "text/plain": [
242
+ "{'P(Hobbies_and_Leisure)': 0.107,\n",
243
+ " 'P(News)': 0.003,\n",
244
+ " 'P(Science)': 0.028,\n",
245
+ " 'P(Autos_and_Vehicles)': 0.083,\n",
246
+ " 'P(Health)': 0.011,\n",
247
+ " 'P(Pets_and_Animals)': 0.006,\n",
248
+ " 'P(Adult)': 0.093,\n",
249
+ " 'P(Computers_and_Electronics)': 1.0,\n",
250
+ " 'P(Online Communities)': 0.116,\n",
251
+ " 'P(Beauty_and_Fitness)': 0.015,\n",
252
+ " 'P(People_and_Society)': 0.0,\n",
253
+ " 'P(Business_and_Industrial)': 0.005,\n",
254
+ " 'P(Reference)': 0.037,\n",
255
+ " 'P(Shopping)': 0.158,\n",
256
+ " 'P(Travel_and_Transportation)': 0.005,\n",
257
+ " 'P(Food_and_Drink)': 0.032,\n",
258
+ " 'P(Law_and_Government)': 0.153,\n",
259
+ " 'P(Books_and_Literature)': 0.008,\n",
260
+ " 'P(Finance)': 0.041,\n",
261
+ " 'P(Games)': 0.063,\n",
262
+ " 'P(Home_and_Garden)': 0.028,\n",
263
+ " 'P(Jobs_and_Education)': 0.004,\n",
264
+ " 'P(Arts_and_Entertainment)': 0.011,\n",
265
+ " 'P(Sensitive Subjects)': 0.004,\n",
266
+ " 'P(Real Estate)': 0.014,\n",
267
+ " 'P(Internet_and_Telecom)': 0.019,\n",
268
+ " 'P(Sports)': 0.023}"
269
+ ]
270
+ },
271
+ "execution_count": 51,
272
+ "metadata": {},
273
+ "output_type": "execute_result"
274
  }
275
  ],
276
  "source": [
 
279
  },
280
  {
281
  "cell_type": "code",
282
+ "execution_count": 36,
283
  "metadata": {},
284
  "outputs": [
285
  {
286
  "name": "stdout",
287
  "output_type": "stream",
288
  "text": [
289
+ "Predicted Class: Food_and_Drink \n",
290
+ "probabilities_scores: 1.0\n",
291
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  ]
293
+ },
294
+ {
295
+ "data": {
296
+ "text/plain": [
297
+ "{'P(Hobbies_and_Leisure)': 0.032,\n",
298
+ " 'P(News)': 0.167,\n",
299
+ " 'P(Science)': 0.019,\n",
300
+ " 'P(Autos_and_Vehicles)': 0.028,\n",
301
+ " 'P(Health)': 0.134,\n",
302
+ " 'P(Pets_and_Animals)': 0.004,\n",
303
+ " 'P(Adult)': 0.018,\n",
304
+ " 'P(Computers_and_Electronics)': 0.223,\n",
305
+ " 'P(Online Communities)': 0.169,\n",
306
+ " 'P(Beauty_and_Fitness)': 0.081,\n",
307
+ " 'P(People_and_Society)': 0.005,\n",
308
+ " 'P(Business_and_Industrial)': 0.011,\n",
309
+ " 'P(Reference)': 0.022,\n",
310
+ " 'P(Shopping)': 0.054,\n",
311
+ " 'P(Travel_and_Transportation)': 0.024,\n",
312
+ " 'P(Food_and_Drink)': 1.0,\n",
313
+ " 'P(Law_and_Government)': 0.016,\n",
314
+ " 'P(Books_and_Literature)': 0.066,\n",
315
+ " 'P(Finance)': 0.01,\n",
316
+ " 'P(Games)': 0.063,\n",
317
+ " 'P(Home_and_Garden)': 0.044,\n",
318
+ " 'P(Jobs_and_Education)': 0.033,\n",
319
+ " 'P(Arts_and_Entertainment)': 0.286,\n",
320
+ " 'P(Sensitive Subjects)': 0.032,\n",
321
+ " 'P(Real Estate)': 0.003,\n",
322
+ " 'P(Internet_and_Telecom)': 0.009,\n",
323
+ " 'P(Sports)': 0.016}"
324
+ ]
325
+ },
326
+ "execution_count": 36,
327
+ "metadata": {},
328
+ "output_type": "execute_result"
329
  }
330
  ],
331
  "source": [
 
334
  },
335
  {
336
  "cell_type": "code",
337
+ "execution_count": 37,
338
  "metadata": {},
339
  "outputs": [
340
  {
341
  "name": "stdout",
342
  "output_type": "stream",
343
  "text": [
344
+ "Predicted Class: Food_and_Drink \n",
345
+ "probabilities_scores: 1.0\n",
346
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  ]
348
+ },
349
+ {
350
+ "data": {
351
+ "text/plain": [
352
+ "{'P(Hobbies_and_Leisure)': 0.048,\n",
353
+ " 'P(News)': 0.202,\n",
354
+ " 'P(Science)': 0.025,\n",
355
+ " 'P(Autos_and_Vehicles)': 0.095,\n",
356
+ " 'P(Health)': 0.094,\n",
357
+ " 'P(Pets_and_Animals)': 0.006,\n",
358
+ " 'P(Adult)': 0.016,\n",
359
+ " 'P(Computers_and_Electronics)': 0.129,\n",
360
+ " 'P(Online Communities)': 0.078,\n",
361
+ " 'P(Beauty_and_Fitness)': 0.122,\n",
362
+ " 'P(People_and_Society)': 0.008,\n",
363
+ " 'P(Business_and_Industrial)': 0.022,\n",
364
+ " 'P(Reference)': 0.014,\n",
365
+ " 'P(Shopping)': 0.046,\n",
366
+ " 'P(Travel_and_Transportation)': 0.024,\n",
367
+ " 'P(Food_and_Drink)': 1.0,\n",
368
+ " 'P(Law_and_Government)': 0.013,\n",
369
+ " 'P(Books_and_Literature)': 0.038,\n",
370
+ " 'P(Finance)': 0.026,\n",
371
+ " 'P(Games)': 0.091,\n",
372
+ " 'P(Home_and_Garden)': 0.025,\n",
373
+ " 'P(Jobs_and_Education)': 0.033,\n",
374
+ " 'P(Arts_and_Entertainment)': 0.233,\n",
375
+ " 'P(Sensitive Subjects)': 0.022,\n",
376
+ " 'P(Real Estate)': 0.005,\n",
377
+ " 'P(Internet_and_Telecom)': 0.003,\n",
378
+ " 'P(Sports)': 0.039}"
379
+ ]
380
+ },
381
+ "execution_count": 37,
382
+ "metadata": {},
383
+ "output_type": "execute_result"
384
  }
385
  ],
386
  "source": [
 
389
  },
390
  {
391
  "cell_type": "code",
392
+ "execution_count": 38,
393
  "metadata": {},
394
  "outputs": [
395
  {
396
  "name": "stdout",
397
  "output_type": "stream",
398
  "text": [
399
+ "Predicted Class: Food_and_Drink \n",
400
+ "probabilities_scores: 0.9980000257492065\n",
401
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  ]
403
+ },
404
+ {
405
+ "data": {
406
+ "text/plain": [
407
+ "{'P(Hobbies_and_Leisure)': 0.113,\n",
408
+ " 'P(News)': 0.037,\n",
409
+ " 'P(Science)': 0.024,\n",
410
+ " 'P(Autos_and_Vehicles)': 0.05,\n",
411
+ " 'P(Health)': 0.039,\n",
412
+ " 'P(Pets_and_Animals)': 0.444,\n",
413
+ " 'P(Adult)': 0.003,\n",
414
+ " 'P(Computers_and_Electronics)': 0.022,\n",
415
+ " 'P(Online Communities)': 0.12,\n",
416
+ " 'P(Beauty_and_Fitness)': 0.114,\n",
417
+ " 'P(People_and_Society)': 0.001,\n",
418
+ " 'P(Business_and_Industrial)': 0.008,\n",
419
+ " 'P(Reference)': 0.003,\n",
420
+ " 'P(Shopping)': 0.014,\n",
421
+ " 'P(Travel_and_Transportation)': 0.009,\n",
422
+ " 'P(Food_and_Drink)': 0.998,\n",
423
+ " 'P(Law_and_Government)': 0.005,\n",
424
+ " 'P(Books_and_Literature)': 0.006,\n",
425
+ " 'P(Finance)': 0.009,\n",
426
+ " 'P(Games)': 0.052,\n",
427
+ " 'P(Home_and_Garden)': 0.006,\n",
428
+ " 'P(Jobs_and_Education)': 0.005,\n",
429
+ " 'P(Arts_and_Entertainment)': 0.199,\n",
430
+ " 'P(Sensitive Subjects)': 0.033,\n",
431
+ " 'P(Real Estate)': 0.003,\n",
432
+ " 'P(Internet_and_Telecom)': 0.001,\n",
433
+ " 'P(Sports)': 0.123}"
434
+ ]
435
+ },
436
+ "execution_count": 38,
437
+ "metadata": {},
438
+ "output_type": "execute_result"
439
  }
440
  ],
441
  "source": [
 
444
  },
445
  {
446
  "cell_type": "code",
447
+ "execution_count": 39,
448
  "metadata": {},
449
  "outputs": [
450
  {
451
  "name": "stdout",
452
  "output_type": "stream",
453
  "text": [
454
+ "Predicted Class: Computers_and_Electronics \n",
455
+ "probabilities_scores: 1.0\n",
456
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  ]
458
+ },
459
+ {
460
+ "data": {
461
+ "text/plain": [
462
+ "{'P(Hobbies_and_Leisure)': 0.134,\n",
463
+ " 'P(News)': 0.002,\n",
464
+ " 'P(Science)': 0.027,\n",
465
+ " 'P(Autos_and_Vehicles)': 0.061,\n",
466
+ " 'P(Health)': 0.008,\n",
467
+ " 'P(Pets_and_Animals)': 0.006,\n",
468
+ " 'P(Adult)': 0.069,\n",
469
+ " 'P(Computers_and_Electronics)': 1.0,\n",
470
+ " 'P(Online Communities)': 0.16,\n",
471
+ " 'P(Beauty_and_Fitness)': 0.015,\n",
472
+ " 'P(People_and_Society)': 0.0,\n",
473
+ " 'P(Business_and_Industrial)': 0.003,\n",
474
+ " 'P(Reference)': 0.019,\n",
475
+ " 'P(Shopping)': 0.147,\n",
476
+ " 'P(Travel_and_Transportation)': 0.005,\n",
477
+ " 'P(Food_and_Drink)': 0.023,\n",
478
+ " 'P(Law_and_Government)': 0.115,\n",
479
+ " 'P(Books_and_Literature)': 0.007,\n",
480
+ " 'P(Finance)': 0.037,\n",
481
+ " 'P(Games)': 0.042,\n",
482
+ " 'P(Home_and_Garden)': 0.032,\n",
483
+ " 'P(Jobs_and_Education)': 0.003,\n",
484
+ " 'P(Arts_and_Entertainment)': 0.01,\n",
485
+ " 'P(Sensitive Subjects)': 0.003,\n",
486
+ " 'P(Real Estate)': 0.012,\n",
487
+ " 'P(Internet_and_Telecom)': 0.016,\n",
488
+ " 'P(Sports)': 0.015}"
489
+ ]
490
+ },
491
+ "execution_count": 39,
492
+ "metadata": {},
493
+ "output_type": "execute_result"
494
  }
495
  ],
496
  "source": [
 
499
  },
500
  {
501
  "cell_type": "code",
502
+ "execution_count": 40,
503
  "metadata": {},
504
  "outputs": [
505
  {
506
  "name": "stdout",
507
  "output_type": "stream",
508
  "text": [
509
+ "Predicted Class: Food_and_Drink \n",
510
+ "probabilities_scores: 0.9909999966621399\n",
511
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  ]
513
+ },
514
+ {
515
+ "data": {
516
+ "text/plain": [
517
+ "{'P(Hobbies_and_Leisure)': 0.02,\n",
518
+ " 'P(News)': 0.017,\n",
519
+ " 'P(Science)': 0.008,\n",
520
+ " 'P(Autos_and_Vehicles)': 0.06,\n",
521
+ " 'P(Health)': 0.032,\n",
522
+ " 'P(Pets_and_Animals)': 0.004,\n",
523
+ " 'P(Adult)': 0.022,\n",
524
+ " 'P(Computers_and_Electronics)': 0.989,\n",
525
+ " 'P(Online Communities)': 0.056,\n",
526
+ " 'P(Beauty_and_Fitness)': 0.026,\n",
527
+ " 'P(People_and_Society)': 0.0,\n",
528
+ " 'P(Business_and_Industrial)': 0.008,\n",
529
+ " 'P(Reference)': 0.052,\n",
530
+ " 'P(Shopping)': 0.105,\n",
531
+ " 'P(Travel_and_Transportation)': 0.012,\n",
532
+ " 'P(Food_and_Drink)': 0.991,\n",
533
+ " 'P(Law_and_Government)': 0.007,\n",
534
+ " 'P(Books_and_Literature)': 0.009,\n",
535
+ " 'P(Finance)': 0.014,\n",
536
+ " 'P(Games)': 0.284,\n",
537
+ " 'P(Home_and_Garden)': 0.015,\n",
538
+ " 'P(Jobs_and_Education)': 0.017,\n",
539
+ " 'P(Arts_and_Entertainment)': 0.031,\n",
540
+ " 'P(Sensitive Subjects)': 0.014,\n",
541
+ " 'P(Real Estate)': 0.003,\n",
542
+ " 'P(Internet_and_Telecom)': 0.003,\n",
543
+ " 'P(Sports)': 0.021}"
544
+ ]
545
+ },
546
+ "execution_count": 40,
547
+ "metadata": {},
548
+ "output_type": "execute_result"
549
  }
550
  ],
551
  "source": [
 
554
  },
555
  {
556
  "cell_type": "code",
557
+ "execution_count": 41,
558
  "metadata": {},
559
  "outputs": [
560
  {
561
  "name": "stdout",
562
  "output_type": "stream",
563
  "text": [
564
+ "Predicted Class: Computers_and_Electronics \n",
565
+ "probabilities_scores: 1.0\n",
566
+ "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  ]
568
+ },
569
+ {
570
+ "data": {
571
+ "text/plain": [
572
+ "{'P(Hobbies_and_Leisure)': 0.054,\n",
573
+ " 'P(News)': 0.003,\n",
574
+ " 'P(Science)': 0.011,\n",
575
+ " 'P(Autos_and_Vehicles)': 0.122,\n",
576
+ " 'P(Health)': 0.01,\n",
577
+ " 'P(Pets_and_Animals)': 0.004,\n",
578
+ " 'P(Adult)': 0.054,\n",
579
+ " 'P(Computers_and_Electronics)': 1.0,\n",
580
+ " 'P(Online Communities)': 0.081,\n",
581
+ " 'P(Beauty_and_Fitness)': 0.016,\n",
582
+ " 'P(People_and_Society)': 0.0,\n",
583
+ " 'P(Business_and_Industrial)': 0.005,\n",
584
+ " 'P(Reference)': 0.064,\n",
585
+ " 'P(Shopping)': 0.224,\n",
586
+ " 'P(Travel_and_Transportation)': 0.006,\n",
587
+ " 'P(Food_and_Drink)': 0.172,\n",
588
+ " 'P(Law_and_Government)': 0.051,\n",
589
+ " 'P(Books_and_Literature)': 0.006,\n",
590
+ " 'P(Finance)': 0.025,\n",
591
+ " 'P(Games)': 0.138,\n",
592
+ " 'P(Home_and_Garden)': 0.03,\n",
593
+ " 'P(Jobs_and_Education)': 0.006,\n",
594
+ " 'P(Arts_and_Entertainment)': 0.008,\n",
595
+ " 'P(Sensitive Subjects)': 0.003,\n",
596
+ " 'P(Real Estate)': 0.006,\n",
597
+ " 'P(Internet_and_Telecom)': 0.004,\n",
598
+ " 'P(Sports)': 0.018}"
599
+ ]
600
+ },
601
+ "execution_count": 41,
602
+ "metadata": {},
603
+ "output_type": "execute_result"
604
  }
605
  ],
606
  "source": [
 
609
  },
610
  {
611
  "cell_type": "code",
612
+ "execution_count": 42,
613
  "metadata": {},
614
+ "outputs": [
615
+ {
616
+ "name": "stdout",
617
+ "output_type": "stream",
618
+ "text": [
619
+ "Predicted Class: Computers_and_Electronics \n",
620
+ "probabilities_scores: 1.0\n",
621
+ "\n"
622
+ ]
623
+ },
624
+ {
625
+ "data": {
626
+ "text/plain": [
627
+ "{'P(Hobbies_and_Leisure)': 0.077,\n",
628
+ " 'P(News)': 0.005,\n",
629
+ " 'P(Science)': 0.009,\n",
630
+ " 'P(Autos_and_Vehicles)': 0.077,\n",
631
+ " 'P(Health)': 0.015,\n",
632
+ " 'P(Pets_and_Animals)': 0.003,\n",
633
+ " 'P(Adult)': 0.073,\n",
634
+ " 'P(Computers_and_Electronics)': 1.0,\n",
635
+ " 'P(Online Communities)': 0.086,\n",
636
+ " 'P(Beauty_and_Fitness)': 0.022,\n",
637
+ " 'P(People_and_Society)': 0.0,\n",
638
+ " 'P(Business_and_Industrial)': 0.004,\n",
639
+ " 'P(Reference)': 0.021,\n",
640
+ " 'P(Shopping)': 0.203,\n",
641
+ " 'P(Travel_and_Transportation)': 0.003,\n",
642
+ " 'P(Food_and_Drink)': 0.241,\n",
643
+ " 'P(Law_and_Government)': 0.009,\n",
644
+ " 'P(Books_and_Literature)': 0.003,\n",
645
+ " 'P(Finance)': 0.029,\n",
646
+ " 'P(Games)': 0.195,\n",
647
+ " 'P(Home_and_Garden)': 0.044,\n",
648
+ " 'P(Jobs_and_Education)': 0.004,\n",
649
+ " 'P(Arts_and_Entertainment)': 0.013,\n",
650
+ " 'P(Sensitive Subjects)': 0.003,\n",
651
+ " 'P(Real Estate)': 0.012,\n",
652
+ " 'P(Internet_and_Telecom)': 0.004,\n",
653
+ " 'P(Sports)': 0.017}"
654
+ ]
655
+ },
656
+ "execution_count": 42,
657
+ "metadata": {},
658
+ "output_type": "execute_result"
659
+ }
660
+ ],
661
  "source": [
662
  "predict(\n",
663
  " 'razer kraken'\n",
664
  ")"
665
  ]
666
  },
667
+ {
668
+ "cell_type": "code",
669
+ "execution_count": 43,
670
+ "metadata": {},
671
+ "outputs": [
672
+ {
673
+ "name": "stdout",
674
+ "output_type": "stream",
675
+ "text": [
676
+ "Predicted Class: Online Communities \n",
677
+ "probabilities_scores: 0.9990000128746033\n",
678
+ "\n"
679
+ ]
680
+ },
681
+ {
682
+ "data": {
683
+ "text/plain": [
684
+ "{'P(Hobbies_and_Leisure)': 0.009,\n",
685
+ " 'P(News)': 0.037,\n",
686
+ " 'P(Science)': 0.014,\n",
687
+ " 'P(Autos_and_Vehicles)': 0.004,\n",
688
+ " 'P(Health)': 0.007,\n",
689
+ " 'P(Pets_and_Animals)': 0.048,\n",
690
+ " 'P(Adult)': 0.287,\n",
691
+ " 'P(Computers_and_Electronics)': 0.536,\n",
692
+ " 'P(Online Communities)': 0.999,\n",
693
+ " 'P(Beauty_and_Fitness)': 0.002,\n",
694
+ " 'P(People_and_Society)': 0.001,\n",
695
+ " 'P(Business_and_Industrial)': 0.002,\n",
696
+ " 'P(Reference)': 0.006,\n",
697
+ " 'P(Shopping)': 0.038,\n",
698
+ " 'P(Travel_and_Transportation)': 0.016,\n",
699
+ " 'P(Food_and_Drink)': 0.012,\n",
700
+ " 'P(Law_and_Government)': 0.024,\n",
701
+ " 'P(Books_and_Literature)': 0.059,\n",
702
+ " 'P(Finance)': 0.001,\n",
703
+ " 'P(Games)': 0.025,\n",
704
+ " 'P(Home_and_Garden)': 0.377,\n",
705
+ " 'P(Jobs_and_Education)': 0.018,\n",
706
+ " 'P(Arts_and_Entertainment)': 0.028,\n",
707
+ " 'P(Sensitive Subjects)': 0.072,\n",
708
+ " 'P(Real Estate)': 0.002,\n",
709
+ " 'P(Internet_and_Telecom)': 0.003,\n",
710
+ " 'P(Sports)': 0.006}"
711
+ ]
712
+ },
713
+ "execution_count": 43,
714
+ "metadata": {},
715
+ "output_type": "execute_result"
716
+ }
717
+ ],
718
+ "source": [
719
+ "predict(\"facebook\")"
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "code",
724
+ "execution_count": 44,
725
+ "metadata": {},
726
+ "outputs": [
727
+ {
728
+ "name": "stdout",
729
+ "output_type": "stream",
730
+ "text": [
731
+ "Predicted Class: Computers_and_Electronics \n",
732
+ "probabilities_scores: 1.0\n",
733
+ "\n"
734
+ ]
735
+ },
736
+ {
737
+ "data": {
738
+ "text/plain": [
739
+ "{'P(Hobbies_and_Leisure)': 0.054,\n",
740
+ " 'P(News)': 0.003,\n",
741
+ " 'P(Science)': 0.011,\n",
742
+ " 'P(Autos_and_Vehicles)': 0.122,\n",
743
+ " 'P(Health)': 0.01,\n",
744
+ " 'P(Pets_and_Animals)': 0.004,\n",
745
+ " 'P(Adult)': 0.054,\n",
746
+ " 'P(Computers_and_Electronics)': 1.0,\n",
747
+ " 'P(Online Communities)': 0.081,\n",
748
+ " 'P(Beauty_and_Fitness)': 0.016,\n",
749
+ " 'P(People_and_Society)': 0.0,\n",
750
+ " 'P(Business_and_Industrial)': 0.005,\n",
751
+ " 'P(Reference)': 0.064,\n",
752
+ " 'P(Shopping)': 0.224,\n",
753
+ " 'P(Travel_and_Transportation)': 0.006,\n",
754
+ " 'P(Food_and_Drink)': 0.172,\n",
755
+ " 'P(Law_and_Government)': 0.051,\n",
756
+ " 'P(Books_and_Literature)': 0.006,\n",
757
+ " 'P(Finance)': 0.025,\n",
758
+ " 'P(Games)': 0.138,\n",
759
+ " 'P(Home_and_Garden)': 0.03,\n",
760
+ " 'P(Jobs_and_Education)': 0.006,\n",
761
+ " 'P(Arts_and_Entertainment)': 0.008,\n",
762
+ " 'P(Sensitive Subjects)': 0.003,\n",
763
+ " 'P(Real Estate)': 0.006,\n",
764
+ " 'P(Internet_and_Telecom)': 0.004,\n",
765
+ " 'P(Sports)': 0.018}"
766
+ ]
767
+ },
768
+ "execution_count": 44,
769
+ "metadata": {},
770
+ "output_type": "execute_result"
771
+ }
772
+ ],
773
+ "source": [
774
+ "predict('apple iphone')"
775
+ ]
776
+ },
777
+ {
778
+ "cell_type": "code",
779
+ "execution_count": 45,
780
+ "metadata": {},
781
+ "outputs": [
782
+ {
783
+ "name": "stdout",
784
+ "output_type": "stream",
785
+ "text": [
786
+ "Predicted Class: Computers_and_Electronics \n",
787
+ "probabilities_scores: 1.0\n",
788
+ "\n"
789
+ ]
790
+ },
791
+ {
792
+ "data": {
793
+ "text/plain": [
794
+ "{'P(Hobbies_and_Leisure)': 0.186,\n",
795
+ " 'P(News)': 0.003,\n",
796
+ " 'P(Science)': 0.009,\n",
797
+ " 'P(Autos_and_Vehicles)': 0.512,\n",
798
+ " 'P(Health)': 0.002,\n",
799
+ " 'P(Pets_and_Animals)': 0.002,\n",
800
+ " 'P(Adult)': 0.039,\n",
801
+ " 'P(Computers_and_Electronics)': 1.0,\n",
802
+ " 'P(Online Communities)': 0.061,\n",
803
+ " 'P(Beauty_and_Fitness)': 0.003,\n",
804
+ " 'P(People_and_Society)': 0.0,\n",
805
+ " 'P(Business_and_Industrial)': 0.001,\n",
806
+ " 'P(Reference)': 0.015,\n",
807
+ " 'P(Shopping)': 0.274,\n",
808
+ " 'P(Travel_and_Transportation)': 0.002,\n",
809
+ " 'P(Food_and_Drink)': 0.009,\n",
810
+ " 'P(Law_and_Government)': 0.058,\n",
811
+ " 'P(Books_and_Literature)': 0.002,\n",
812
+ " 'P(Finance)': 0.033,\n",
813
+ " 'P(Games)': 0.151,\n",
814
+ " 'P(Home_and_Garden)': 0.027,\n",
815
+ " 'P(Jobs_and_Education)': 0.002,\n",
816
+ " 'P(Arts_and_Entertainment)': 0.005,\n",
817
+ " 'P(Sensitive Subjects)': 0.001,\n",
818
+ " 'P(Real Estate)': 0.035,\n",
819
+ " 'P(Internet_and_Telecom)': 0.001,\n",
820
+ " 'P(Sports)': 0.008}"
821
+ ]
822
+ },
823
+ "execution_count": 45,
824
+ "metadata": {},
825
+ "output_type": "execute_result"
826
+ }
827
+ ],
828
+ "source": [
829
+ "predict('best vr')"
830
+ ]
831
+ },
832
+ {
833
+ "cell_type": "code",
834
+ "execution_count": 46,
835
+ "metadata": {},
836
+ "outputs": [
837
+ {
838
+ "name": "stdout",
839
+ "output_type": "stream",
840
+ "text": [
841
+ "Predicted Class: Computers_and_Electronics \n",
842
+ "probabilities_scores: 1.0\n",
843
+ "\n"
844
+ ]
845
+ },
846
+ {
847
+ "data": {
848
+ "text/plain": [
849
+ "{'P(Hobbies_and_Leisure)': 0.186,\n",
850
+ " 'P(News)': 0.003,\n",
851
+ " 'P(Science)': 0.009,\n",
852
+ " 'P(Autos_and_Vehicles)': 0.512,\n",
853
+ " 'P(Health)': 0.002,\n",
854
+ " 'P(Pets_and_Animals)': 0.002,\n",
855
+ " 'P(Adult)': 0.039,\n",
856
+ " 'P(Computers_and_Electronics)': 1.0,\n",
857
+ " 'P(Online Communities)': 0.061,\n",
858
+ " 'P(Beauty_and_Fitness)': 0.003,\n",
859
+ " 'P(People_and_Society)': 0.0,\n",
860
+ " 'P(Business_and_Industrial)': 0.001,\n",
861
+ " 'P(Reference)': 0.015,\n",
862
+ " 'P(Shopping)': 0.274,\n",
863
+ " 'P(Travel_and_Transportation)': 0.002,\n",
864
+ " 'P(Food_and_Drink)': 0.009,\n",
865
+ " 'P(Law_and_Government)': 0.058,\n",
866
+ " 'P(Books_and_Literature)': 0.002,\n",
867
+ " 'P(Finance)': 0.033,\n",
868
+ " 'P(Games)': 0.151,\n",
869
+ " 'P(Home_and_Garden)': 0.027,\n",
870
+ " 'P(Jobs_and_Education)': 0.002,\n",
871
+ " 'P(Arts_and_Entertainment)': 0.005,\n",
872
+ " 'P(Sensitive Subjects)': 0.001,\n",
873
+ " 'P(Real Estate)': 0.035,\n",
874
+ " 'P(Internet_and_Telecom)': 0.001,\n",
875
+ " 'P(Sports)': 0.008}"
876
+ ]
877
+ },
878
+ "execution_count": 46,
879
+ "metadata": {},
880
+ "output_type": "execute_result"
881
+ }
882
+ ],
883
+ "source": [
884
+ "predict(\"best vr\")"
885
+ ]
886
+ },
887
+ {
888
+ "cell_type": "code",
889
+ "execution_count": 47,
890
+ "metadata": {},
891
+ "outputs": [
892
+ {
893
+ "name": "stdout",
894
+ "output_type": "stream",
895
+ "text": [
896
+ "Predicted Class: Adult \n",
897
+ "probabilities_scores: 0.7149999737739563\n",
898
+ "\n"
899
+ ]
900
+ },
901
+ {
902
+ "data": {
903
+ "text/plain": [
904
+ "{'P(Hobbies_and_Leisure)': 0.684,\n",
905
+ " 'P(News)': 0.009,\n",
906
+ " 'P(Science)': 0.001,\n",
907
+ " 'P(Autos_and_Vehicles)': 0.004,\n",
908
+ " 'P(Health)': 0.001,\n",
909
+ " 'P(Pets_and_Animals)': 0.0,\n",
910
+ " 'P(Adult)': 0.715,\n",
911
+ " 'P(Computers_and_Electronics)': 0.274,\n",
912
+ " 'P(Online Communities)': 0.246,\n",
913
+ " 'P(Beauty_and_Fitness)': 0.003,\n",
914
+ " 'P(People_and_Society)': 0.001,\n",
915
+ " 'P(Business_and_Industrial)': 0.0,\n",
916
+ " 'P(Reference)': 0.0,\n",
917
+ " 'P(Shopping)': 0.022,\n",
918
+ " 'P(Travel_and_Transportation)': 0.001,\n",
919
+ " 'P(Food_and_Drink)': 0.002,\n",
920
+ " 'P(Law_and_Government)': 0.021,\n",
921
+ " 'P(Books_and_Literature)': 0.007,\n",
922
+ " 'P(Finance)': 0.003,\n",
923
+ " 'P(Games)': 0.012,\n",
924
+ " 'P(Home_and_Garden)': 0.178,\n",
925
+ " 'P(Jobs_and_Education)': 0.002,\n",
926
+ " 'P(Arts_and_Entertainment)': 0.01,\n",
927
+ " 'P(Sensitive Subjects)': 0.001,\n",
928
+ " 'P(Real Estate)': 0.026,\n",
929
+ " 'P(Internet_and_Telecom)': 0.0,\n",
930
+ " 'P(Sports)': 0.02}"
931
+ ]
932
+ },
933
+ "execution_count": 47,
934
+ "metadata": {},
935
+ "output_type": "execute_result"
936
+ }
937
+ ],
938
+ "source": [
939
+ "predict(\"pa best views\")"
940
+ ]
941
+ },
942
  {
943
  "cell_type": "code",
944
  "execution_count": null,
945
  "metadata": {},
946
  "outputs": [],
947
+ "source": []
948
+ },
949
+ {
950
+ "cell_type": "code",
951
+ "execution_count": null,
952
+ "metadata": {},
953
+ "outputs": [],
954
+ "source": []
955
+ },
956
+ {
957
+ "cell_type": "code",
958
+ "execution_count": 10,
959
+ "metadata": {},
960
+ "outputs": [],
961
  "source": [
962
+ "inputs = tokenizer(\"best cat ear headphones\", return_tensors=\"pt\")\n",
963
+ "with torch.no_grad():\n",
964
+ " logits = model(**inputs).logits"
965
+ ]
966
+ },
967
+ {
968
+ "cell_type": "code",
969
+ "execution_count": 14,
970
+ "metadata": {},
971
+ "outputs": [
972
+ {
973
+ "data": {
974
+ "text/plain": [
975
+ "array([-1.353771 , -5.8301578, -4.050355 , -1.9018538, -5.129807 ,\n",
976
+ " -5.2707334, -2.696651 , 8.821061 , -2.0982835, -4.4173856,\n",
977
+ " -9.076361 , -5.888918 , -3.7155762, -1.0305756, -5.5817475,\n",
978
+ " -3.987473 , -2.4096951, -5.1136127, -3.217719 , -2.938894 ,\n",
979
+ " -3.7113686, -5.8976064, -4.788314 , -6.4181705, -3.5685277,\n",
980
+ " -4.5266075, -4.3206973], dtype=float32)"
981
+ ]
982
+ },
983
+ "execution_count": 14,
984
+ "metadata": {},
985
+ "output_type": "execute_result"
986
+ }
987
+ ],
988
+ "source": [
989
+ "l= logits.numpy()[0]\n",
990
+ "l"
991
+ ]
992
+ },
993
+ {
994
+ "cell_type": "code",
995
+ "execution_count": 18,
996
+ "metadata": {},
997
+ "outputs": [],
998
+ "source": [
999
+ "# logit2prob <- function(logit){\n",
1000
+ "# odds <- exp(logit)\n",
1001
+ "# prob <- odds / (1 + odds)\n",
1002
+ "# return(prob)\n",
1003
+ "# }\n",
1004
+ "def logit2prob(logit):\n",
1005
+ " odds =np.exp(logit)\n",
1006
+ " prob = odds / (1 + odds)\n",
1007
+ " return np.round(prob, 2)"
1008
+ ]
1009
+ },
1010
+ {
1011
+ "cell_type": "code",
1012
+ "execution_count": 17,
1013
+ "metadata": {},
1014
+ "outputs": [
1015
+ {
1016
+ "name": "stdout",
1017
+ "output_type": "stream",
1018
+ "text": [
1019
+ "0.21\n",
1020
+ "0.0\n",
1021
+ "0.02\n",
1022
+ "0.13\n",
1023
+ "0.01\n",
1024
+ "0.01\n",
1025
+ "0.06\n",
1026
+ "1.0\n",
1027
+ "0.11\n",
1028
+ "0.01\n",
1029
+ "0.0\n",
1030
+ "0.0\n",
1031
+ "0.02\n",
1032
+ "0.26\n",
1033
+ "0.0\n",
1034
+ "0.02\n",
1035
+ "0.08\n",
1036
+ "0.01\n",
1037
+ "0.04\n",
1038
+ "0.05\n",
1039
+ "0.02\n",
1040
+ "0.0\n",
1041
+ "0.01\n",
1042
+ "0.0\n",
1043
+ "0.03\n",
1044
+ "0.01\n",
1045
+ "0.01\n"
1046
+ ]
1047
+ }
1048
+ ],
1049
+ "source": [
1050
+ "for i in l:\n",
1051
+ " print(round(logit2prob(i), 2))"
1052
+ ]
1053
+ },
1054
+ {
1055
+ "cell_type": "code",
1056
+ "execution_count": 19,
1057
+ "metadata": {},
1058
+ "outputs": [
1059
+ {
1060
+ "data": {
1061
+ "text/plain": [
1062
+ "array([0.21, 0. , 0.02, 0.13, 0.01, 0.01, 0.06, 1. , 0.11, 0.01, 0. ,\n",
1063
+ " 0. , 0.02, 0.26, 0. , 0.02, 0.08, 0.01, 0.04, 0.05, 0.02, 0. ,\n",
1064
+ " 0.01, 0. , 0.03, 0.01, 0.01], dtype=float32)"
1065
+ ]
1066
+ },
1067
+ "execution_count": 19,
1068
+ "metadata": {},
1069
+ "output_type": "execute_result"
1070
+ }
1071
+ ],
1072
+ "source": [
1073
+ "logit2prob(l)"
1074
  ]
1075
  },
1076
  {