orionweller
commited on
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
tags:
|
6 |
+
- retrieval
|
7 |
+
- instructions
|
8 |
+
- reranking
|
9 |
+
---
|
10 |
+
|
11 |
+
# Model Summary
|
12 |
+
|
13 |
+
> FollowIR-7B is an instruction-tuned language model to be used for reranking in retrieval. It is Mistral-7B-Instruct-v0.2 fine-tuned on instruction retrieval data.
|
14 |
+
|
15 |
+
- **Repository:** [ContextualAI/gritlm](https://github.com/ContextualAI/gritlm)
|
16 |
+
- **Paper:** https://arxiv.org/abs/2402.09906
|
17 |
+
- **Logs:** https://wandb.ai/muennighoff/gritlm/runs/0uui712t/overview
|
18 |
+
- **Script:** https://github.com/ContextualAI/gritlm/blob/main/scripts/training/train_gritlm_7b.sh
|
19 |
+
|
20 |
+
| Model | Description |
|
21 |
+
|-------|-------------|
|
22 |
+
| [GritLM 7B](https://hf.co/GritLM/GritLM-7B) | Mistral 7B finetuned using GRIT |
|
23 |
+
| [GritLM 8x7B](https://hf.co/GritLM/GritLM-8x7B) | Mixtral 8x7B finetuned using GRIT |
|
24 |
+
|
25 |
+
# Use
|
26 |
+
|
27 |
+
Below is an example to compute the similarity score of a query-document pair
|
28 |
+
```python
|
29 |
+
model_name = "jhu-clsp/FollowIR-7B"
|
30 |
+
model = AutoModelForCausalLM.from_pretrained(
|
31 |
+
model_name
|
32 |
+
).cuda()
|
33 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
34 |
+
model_name, padding_side="left"
|
35 |
+
)
|
36 |
+
tokenizer.pad_token = tokenizer.eos_token
|
37 |
+
tokenizer.padding_side = "left"
|
38 |
+
token_false_id = tokenizer.get_vocab()["false"]
|
39 |
+
token_true_id = tokenizer.get_vocab()["true"]
|
40 |
+
max_length = min(2048, tokenizer.model_max_length)
|
41 |
+
|
42 |
+
template = """<s> [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.
|
43 |
+
|
44 |
+
Query: {query}
|
45 |
+
Document: {text}
|
46 |
+
Relevant (only output one word, either "true" or "false"): [/INST] """
|
47 |
+
|
48 |
+
|
49 |
+
prompts = [
|
50 |
+
template.format(query=query, text=text) for (query, text) in zip([query] * 2, passages)
|
51 |
+
]
|
52 |
+
tokens = tokenizer(
|
53 |
+
prompts,
|
54 |
+
padding=True,
|
55 |
+
truncation=True,
|
56 |
+
return_tensors="pt",
|
57 |
+
max_length=max_length,
|
58 |
+
pad_to_multiple_of=None,
|
59 |
+
)
|
60 |
+
|
61 |
+
if "token_type_ids" in tokens:
|
62 |
+
del tokens["token_type_ids"]
|
63 |
+
|
64 |
+
# move to cuda if desired
|
65 |
+
for key in tokens:
|
66 |
+
tokens[key] = tokens[key].cuda()
|
67 |
+
|
68 |
+
# calculate the scores
|
69 |
+
batch_scores = model(**tokens).logits[:, -1, :]
|
70 |
+
true_vector = batch_scores[:, token_true_id]
|
71 |
+
false_vector = batch_scores[:, token_false_id]
|
72 |
+
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
73 |
+
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
74 |
+
scores = batch_scores[:, 1].exp().tolist()
|
75 |
+
```
|
76 |
+
|
77 |
+
# Citation
|
78 |
+
|
79 |
+
```bibtex
|
80 |
+
TODO
|
81 |
+
```
|