ai-forever
commited on
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
---
|
3 |
+
# Model Card for FRIDA
|
4 |
+
|
5 |
+
The FRIDA is a general text embedding model for Russian. The model is based on the encoder part of FRED-T5 (https://huggingface.co/ai-forever/FRED-T5-1.7B). It has been pre-trained on a Russian-English dataset and fine-tuned for improved performance on the target task.
|
6 |
+
|
7 |
+
For more model details please refer to our [article](TODO).
|
8 |
+
|
9 |
+
## Usage
|
10 |
+
|
11 |
+
The model can be used as is with prefixes. It is recommended to use CLS pooling. The choice of prefix and pooling depends on the task.
|
12 |
+
|
13 |
+
We use the following basic rules to choose a prefix:
|
14 |
+
- `"search_query: "` and `"search_document: "` prefixes are for answer or relevant paragraph retrieval
|
15 |
+
- `"paraphrase: "` prefix is for symmetric paraphrasing related tasks (STS, paraphrase mining, deduplication)
|
16 |
+
- `"categorize: "` prefix is for asymmetric matching of document title and body (e.g. news, scientific papers, social posts)
|
17 |
+
- `"categorize_sentiment: "` prefix is for any tasks that rely on sentiment features (e.g. hate, toxic, emotion)
|
18 |
+
- `"categorize_topic: "` prefix is intended for tasks where you need to group texts by topic
|
19 |
+
- `"categorize_entailment: "` prefix is for textual entailment task (NLI)
|
20 |
+
|
21 |
+
To better tailor the model to your needs, you can fine-tune it with relevant high-quality Russian and English datasets.
|
22 |
+
|
23 |
+
Below are examples of texts encoding using the Transformers and SentenceTransformers libraries.
|
24 |
+
|
25 |
+
### Transformers
|
26 |
+
|
27 |
+
```python
|
28 |
+
import torch
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
31 |
+
|
32 |
+
|
33 |
+
def pool(hidden_state, mask, pooling_method="cls"):
|
34 |
+
if pooling_method == "mean":
|
35 |
+
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
|
36 |
+
d = mask.sum(axis=1, keepdim=True).float()
|
37 |
+
return s / d
|
38 |
+
elif pooling_method == "cls":
|
39 |
+
return hidden_state[:, 0]
|
40 |
+
|
41 |
+
inputs = [
|
42 |
+
#
|
43 |
+
"paraphrase: Он нам и <unk> не нужон ваш Интернет!",
|
44 |
+
"categorize_entailment: В Ярославской области разрешили работу бань, но без посетителей",
|
45 |
+
"search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
|
46 |
+
#
|
47 |
+
"paraphrase: What a time to be alive!",
|
48 |
+
"categorize_entailment: Ярославским баням разрешили работать без посетителей",
|
49 |
+
"search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
|
50 |
+
]
|
51 |
+
|
52 |
+
tokenizer = AutoTokenizer.from_pretrained("ai-forever/FRIDA")
|
53 |
+
model = T5EncoderModel.from_pretrained("ai-forever/FRIDA")
|
54 |
+
|
55 |
+
tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")
|
56 |
+
|
57 |
+
with torch.no_grad():
|
58 |
+
outputs = model(**tokenized_inputs)
|
59 |
+
|
60 |
+
embeddings = pool(
|
61 |
+
outputs.last_hidden_state,
|
62 |
+
tokenized_inputs["attention_mask"],
|
63 |
+
pooling_method="cls" # or try "mean"
|
64 |
+
)
|
65 |
+
|
66 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
67 |
+
sim_scores = embeddings[:3] @ embeddings[3:].T
|
68 |
+
print(sim_scores.diag().tolist())
|
69 |
+
# [0.4796873927116394, 0.9409002065658569, 0.7761015892028809]
|
70 |
+
```
|
71 |
+
|
72 |
+
### SentenceTransformers
|
73 |
+
|
74 |
+
```python
|
75 |
+
from sentence_transformers import SentenceTransformer
|
76 |
+
|
77 |
+
inputs = [
|
78 |
+
#
|
79 |
+
"paraphrase: Он нам и <unk> не нужон ваш Интернет!",
|
80 |
+
"categorize_entailment: В Ярославской области разрешили работу бань, но без посетителей",
|
81 |
+
"search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
|
82 |
+
#
|
83 |
+
"paraphrase: What a time to be alive!",
|
84 |
+
"categorize_entailment: Ярославским баням разрешили работать без посетителей",
|
85 |
+
"search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.",
|
86 |
+
]
|
87 |
+
|
88 |
+
# loads model with CLS pooling
|
89 |
+
model = SentenceTransformer("ai-forever/FRIDA")
|
90 |
+
|
91 |
+
# embeddings are normalized by default
|
92 |
+
embeddings = model.encode(inputs, convert_to_tensor=True)
|
93 |
+
|
94 |
+
sim_scores = embeddings[:3] @ embeddings[3:].T
|
95 |
+
print(sim_scores.diag().tolist())
|
96 |
+
# [0.47968706488609314, 0.940900444984436, 0.7761018872261047]
|
97 |
+
```
|
98 |
+
|
99 |
+
or using prompts (sentence-transformers>=2.4.0):
|
100 |
+
|
101 |
+
```python
|
102 |
+
from sentence_transformers import SentenceTransformer
|
103 |
+
|
104 |
+
# loads model with CLS pooling
|
105 |
+
model = SentenceTransformer("ai-forever/FRIDA")
|
106 |
+
|
107 |
+
classification = model.encode(["Он нам и <unk> не нужон ваш Интернет!", "What a time to be alive!"], prompt_name="paraphrase")
|
108 |
+
print(classification[0] @ classification[1].T) # 0.47968706488609314
|
109 |
+
|
110 |
+
clustering = model.encode(["В Ярославской области разрешили работу бань, но без посетителей", "Ярославским баням разрешили работать без посетителей"], prompt_name="categorize_entailment")
|
111 |
+
print(clustering[0] @ clustering[1].T) # 0.940900444984436
|
112 |
+
|
113 |
+
query_embedding = model.encode("Сколько программистов нужно, чтобы вкрутить лампочку?", prompt_name="search_query")
|
114 |
+
document_embedding = model.encode("Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.", prompt_name="search_document")
|
115 |
+
print(query_embedding @ document_embedding.T) # 0.7761018872261047
|
116 |
+
```
|
117 |
+
|
118 |
+
## Citation
|
119 |
+
|
120 |
+
```
|
121 |
+
@misc{TODO
|
122 |
+
}
|
123 |
+
```
|
124 |
+
|
125 |
+
## Limitations
|
126 |
+
|
127 |
+
The model is designed to process texts in Russian, the quality in English is unknown. Maximum input text length is limited to 512 tokens.
|