ai-forever commited on
Commit
36d796b
·
verified ·
1 Parent(s): 537c839

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +127 -0
README.md ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ # Model Card for FRIDA
4
+
5
+ The FRIDA is a general text embedding model for Russian. The model is based on the encoder part of FRED-T5 (https://huggingface.co/ai-forever/FRED-T5-1.7B). It has been pre-trained on a Russian-English dataset and fine-tuned for improved performance on the target task.
6
+
7
+ For more model details please refer to our [article](TODO).
8
+
9
+ ## Usage
10
+
11
+ The model can be used as is with prefixes. It is recommended to use CLS pooling. The choice of prefix and pooling depends on the task.
12
+
13
+ We use the following basic rules to choose a prefix:
14
+ - `"search_query: "` and `"search_document: "` prefixes are for answer or relevant paragraph retrieval
15
+ - `"paraphrase: "` prefix is for symmetric paraphrasing related tasks (STS, paraphrase mining, deduplication)
16
+ - `"categorize: "` prefix is for asymmetric matching of document title and body (e.g. news, scientific papers, social posts)
17
+ - `"categorize_sentiment: "` prefix is for any tasks that rely on sentiment features (e.g. hate, toxic, emotion)
18
+ - `"categorize_topic: "` prefix is ​​intended for tasks where you need to group texts by topic
19
+ - `"categorize_entailment: "` prefix is for textual entailment task (NLI)
20
+
21
+ To better tailor the model to your needs, you can fine-tune it with relevant high-quality Russian and English datasets.
22
+
23
+ Below are examples of texts encoding using the Transformers and SentenceTransformers libraries.
24
+
25
+ ### Transformers
26
+
27
+ ```python
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from transformers import AutoTokenizer, T5EncoderModel
31
+
32
+
33
+ def pool(hidden_state, mask, pooling_method="cls"):
34
+ if pooling_method == "mean":
35
+ s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
36
+ d = mask.sum(axis=1, keepdim=True).float()
37
+ return s / d
38
+ elif pooling_method == "cls":
39
+ return hidden_state[:, 0]
40
+
41
+ inputs = [
42
+ #
43
+ "paraphrase: Он нам и <unk> не нужон ваш Интернет!",
44
+ "categorize_entailment: В Ярославской области разрешили работу бань, но без посетителей",
45
+ "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
46
+ #
47
+ "paraphrase: What a time to be alive!",
48
+ "categorize_entailment: Ярославским баням разрешили работать без посетителей",
49
+ "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
50
+ ]
51
+
52
+ tokenizer = AutoTokenizer.from_pretrained("ai-forever/FRIDA")
53
+ model = T5EncoderModel.from_pretrained("ai-forever/FRIDA")
54
+
55
+ tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")
56
+
57
+ with torch.no_grad():
58
+ outputs = model(**tokenized_inputs)
59
+
60
+ embeddings = pool(
61
+ outputs.last_hidden_state,
62
+ tokenized_inputs["attention_mask"],
63
+ pooling_method="cls" # or try "mean"
64
+ )
65
+
66
+ embeddings = F.normalize(embeddings, p=2, dim=1)
67
+ sim_scores = embeddings[:3] @ embeddings[3:].T
68
+ print(sim_scores.diag().tolist())
69
+ # [0.4796873927116394, 0.9409002065658569, 0.7761015892028809]
70
+ ```
71
+
72
+ ### SentenceTransformers
73
+
74
+ ```python
75
+ from sentence_transformers import SentenceTransformer
76
+
77
+ inputs = [
78
+ #
79
+ "paraphrase: Он нам и <unk> не нужон ваш Интернет!",
80
+ "categorize_entailment: В Ярославской области разрешили работу бань, но без посетителей",
81
+ "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
82
+ #
83
+ "paraphrase: What a time to be alive!",
84
+ "categorize_entailment: Ярославским баням разрешили работать без посетителей",
85
+ "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
86
+ ]
87
+
88
+ # loads model with CLS pooling
89
+ model = SentenceTransformer("ai-forever/FRIDA")
90
+
91
+ # embeddings are normalized by default
92
+ embeddings = model.encode(inputs, convert_to_tensor=True)
93
+
94
+ sim_scores = embeddings[:3] @ embeddings[3:].T
95
+ print(sim_scores.diag().tolist())
96
+ # [0.47968706488609314, 0.940900444984436, 0.7761018872261047]
97
+ ```
98
+
99
+ or using prompts (sentence-transformers>=2.4.0):
100
+
101
+ ```python
102
+ from sentence_transformers import SentenceTransformer
103
+
104
+ # loads model with CLS pooling
105
+ model = SentenceTransformer("ai-forever/FRIDA")
106
+
107
+ classification = model.encode(["Он нам и <unk> не нужон ваш Интернет!", "What a time to be alive!"], prompt_name="paraphrase")
108
+ print(classification[0] @ classification[1].T) # 0.47968706488609314
109
+
110
+ clustering = model.encode(["В Ярославской области разрешили работу бань, но без посетителей", "Ярославским баням разрешили работать без посетителей"], prompt_name="categorize_entailment")
111
+ print(clustering[0] @ clustering[1].T) # 0.940900444984436
112
+
113
+ query_embedding = model.encode("Сколько программистов нужно, чтобы вкрутить лампочку?", prompt_name="search_query")
114
+ document_embedding = model.encode("Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.", prompt_name="search_document")
115
+ print(query_embedding @ document_embedding.T) # 0.7761018872261047
116
+ ```
117
+
118
+ ## Citation
119
+
120
+ ```
121
+ @misc{TODO
122
+ }
123
+ ```
124
+
125
+ ## Limitations
126
+
127
+ The model is designed to process texts in Russian, the quality in English is unknown. Maximum input text length is limited to 512 tokens.