pixas committed on
Commit fa36636 · verified · 1 Parent(s): 2e5c8d0

Update README.md

Files changed (1)
  1. README.md +85 -3
README.md CHANGED
@@ -1,3 +1,85 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ language:
+ - en
+ base_model:
+ - meta-llama/Llama-3.1-8B-Instruct
+ pipeline_tag: token-classification
+ ---
+
+ <div align="center">
+ <h1>
+ MedSSS-8B-PRM
+ </h1>
+ </div>
+
+ <div align="center">
+ <a href="https://github.com/pixas/MedSSS" target="_blank">GitHub</a> | <a href="" target="_blank">Paper</a>
+ </div>
+
+ # <span>Introduction</span>
+ **MedSSS-PRM** is the PRM model designed for slow-thinking medical reasoning. It assigns a float value in `[0, 1]` to every internal reasoning step of **MedSSS-Policy**.
+
+ For more information, visit our GitHub repository:
+ [https://github.com/pixas/MedSSS](https://github.com/pixas/MedSSS).
+
+ # <span>Usage</span>
+ We build the PRM model as a LoRA adapter, which reduces the memory required to use it.
+ As this LoRA adapter is built on `meta-llama/Llama-3.1-8B-Instruct`, you first need to prepare the base model on your platform.
+
+ ```python
+ from itertools import chain
+
+ import torch
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
+ from peft import PeftModel
+
+
+ def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
+     # `outputs` is the step-wise generation produced by MedSSS-Policy
+     response = outputs
+     # Split the response into steps; steps are separated by "\n\nStep"
+     completions = [
+         completion if completion.startswith("Step") else "Step" + completion
+         for completion in outputs.split("\n\nStep")
+     ]
+
+     messages = [
+         {"role": "user", "content": inputs},
+         {"role": "assistant", "content": response}
+     ]
+     input_text = tokenizer.apply_chat_template(messages, tokenize=False)
+
+     # Locate the assistant response inside the chat-templated text
+     response_begin_index = input_text.index(response)
+     pre_response_input = input_text[:response_begin_index]
+     after_response_input = input_text[response_begin_index + len(response):]
+
+     # Tokenize each step separately so the last token of every step can be located
+     completion_ids = [
+         tokenizer(completion + "\n\n", add_special_tokens=False)['input_ids']
+         for completion in completions
+     ]
+     response_id = list(chain(*completion_ids))
+     pre_response_id = tokenizer(pre_response_input, add_special_tokens=False)['input_ids']
+     after_response_id = tokenizer(after_response_input, add_special_tokens=False)['input_ids']
+
+     input_ids = pre_response_id + response_id + after_response_id
+
+     with torch.no_grad():
+         # Per-token values, squeezed to shape [1, N]
+         # (assumes the token-classification head outputs a single value per token)
+         value = value_model(
+             input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device)
+         ).logits.squeeze(-1)
+
+     # Index of the last token of each step
+     completion_index = []
+     for i, completion in enumerate(completion_ids):
+         if i == 0:
+             completion_index.append(len(completion) + len(pre_response_id) - 1)
+         else:
+             completion_index.append(completion_index[-1] + len(completion))
+
+     step_value = value[0, completion_index].float().cpu().numpy().tolist()
+     return step_value
+
+
+ base_model = AutoModelForTokenClassification.from_pretrained(
+     "meta-llama/Llama-3.1-8B-Instruct", torch_dtype="auto", device_map="auto"
+ )
+ model = PeftModel.from_pretrained(base_model, "pixas/MedSSS_PRM", torch_dtype="auto", device_map="auto")
+ tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")
+
+ input_text = "How to stop a cough?"
+ step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"
+
+ value = obtain_prm_value_for_single_pair(tokenizer, model, input_text, step_wise_generation)
+ print(value)
+ ```
+
+ MedSSS-PRM uses `"\n\nStep"` to separate intermediate steps, so `obtain_prm_value_for_single_pair` returns one value per step, read from the token immediately before the next `"Step k: "` marker (or the end of the sequence for the final step).
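+
+ The snippet below is a minimal sketch of this splitting convention; the example generation string is hypothetical, and the logic mirrors the step handling inside `obtain_prm_value_for_single_pair` above.
+
+ ```python
+ # Hypothetical step-wise generation, used only to illustrate the "\n\nStep" format
+ generation = "Step 0: Break the problem down.\n\nStep 1: Identify likely causes.\n\nStep 2: Recommend a treatment."
+
+ parts = generation.split("\n\nStep")
+ steps = [p if p.startswith("Step") else "Step" + p for p in parts]
+ print(steps)
+ # ['Step 0: Break the problem down.', 'Step 1: Identify likely causes.', 'Step 2: Recommend a treatment.']
+
+ # The PRM yields one value per step, read at the last token of that step, i.e. right
+ # before the next "Step k:" marker (or at the end of the sequence for the final step).
+ ```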