Update README.md
README.md CHANGED

@@ -19,7 +19,7 @@ Determine if the student's answer is correct or not. It only returns True if the
 Add a brief comment explaining why the answer is correct or incorrect.\n\n
 Question: {question}\n
 Expected Answer: {best_answer}\n
-Student Answer: {student_answer}[/INST]
+Student Answer: {student_answer}[/INST]"
 ```

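The lines above are the tail of the card's inference prompt template. As a rough illustration, a minimal sketch of how a prompt in this format might be assembled; the values and the shortened system message below are placeholders, not taken from the model card:

```python
# Sketch only: placeholder values to illustrate the [INST] ... [/INST] template above.
question = "Name one Canary Island"           # placeholder
best_answer = "Tenerife"                      # placeholder
student_answer = "I think it is Tenerife."    # placeholder

# Abbreviated stand-in for the card's full system message.
system_message = (
    "Determine if the student's answer is correct or not. "
    "Add a brief comment explaining why the answer is correct or incorrect."
)

prompt_template = (
    f"<s>[INST] {system_message}\n\n"
    f"Question: {question}\n"
    f"Expected Answer: {best_answer}\n"
    f"Student Answer: {student_answer}[/INST]"
)
print(prompt_template)
```
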
@@ -84,8 +84,63 @@ Student Answer: {student_answer}[/INST]</s>"
 Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

 ## How to Get Started with the Model
-
-
+In Google Colab:
+```python
+!pip install -q -U transformers peft accelerate optimum
+!pip install datasets==2.15.0
+!pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/
+
+import torch
+from peft import AutoPeftModelForCausalLM
+from rich import print
+from transformers import GenerationConfig, AutoTokenizer
+
+model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
+adapter = "nmarafo/Mistral-7B-Instruct-v0.2-TrueFalse-Feedback-GPTQ"
+
+def generate_prompt(data_point):
+    system_message = "Analyze the question, the expected answer, and the student's response. Determine if the student's answer is conceptually correct in relation to the expected answer, regardless of the exact wording. An answer will be considered correct if it accurately identifies the key information requested in the question, even if expressed differently. Return True if the student's answer is correct or False otherwise. Add a brief comment explaining the rationale behind the answer being correct or incorrect."
+    question = data_point["question"][0]
+    best_answer = data_point["best_answer"][0]
+    student_answer = data_point["student_answer"][0]
+    prompt = f"{system_message}\n\nQuestion: {question}\nExpected Answer: {best_answer}\nStudent Answer: {student_answer}"
+    return prompt
+
+tokenizer = AutoTokenizer.from_pretrained(
+    model_id,
+    trust_remote_code=True,
+    return_token_type_ids=False)
+tokenizer.pad_token = tokenizer.eos_token
+
+# Load the adapter checkpoint (on top of the GPTQ base model) before running inference
+persisted_model = AutoPeftModelForCausalLM.from_pretrained(
+    adapter,
+    low_cpu_mem_usage=True,
+    return_dict=True,
+    torch_dtype=torch.float16,
+    device_map="cuda")
+
+# Example: grade a single student answer
+question = "Name of Canary Island"
+best_answer = "Tenerife, Fuerteventura, Gran Canaria, Lanzarote, La Palma, La Gomera, El Hierro, La Graciosa"
+student_answer = "Tenerife"
+
+prompt = generate_prompt({"question": [question], "best_answer": [best_answer], "student_answer": [student_answer]})
+prompt_template = f'''<s>[INST] {prompt} [/INST]'''
+
+input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
+output = persisted_model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
+print(tokenizer.decode(output[0]))
+
+# Alternatively, bundle the generation settings in a GenerationConfig and reuse it
+generation_config = GenerationConfig(
+    penalty_alpha=0.6,
+    do_sample=True,
+    top_k=5,
+    temperature=0.5,
+    repetition_penalty=1.2,
+    max_new_tokens=512
+)
+output = persisted_model.generate(inputs=input_ids, generation_config=generation_config)
+print(tokenizer.decode(output[0]))
+```

 [More Information Needed]

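The decoded string contains the prompt followed by the model's verdict and brief comment. A minimal sketch of one way to split those apart, assuming the completion starts right after the final `[/INST]` tag with the True/False verdict first; the `parse_feedback` helper and the example string are illustrative, not part of the model card:

```python
def parse_feedback(decoded: str):
    # The model's completion follows the last [/INST] tag; drop any trailing EOS marker.
    completion = decoded.split("[/INST]")[-1].replace("</s>", "").strip()
    # Heuristic: the verdict ("True"/"False") comes first, then the brief comment.
    verdict, _, comment = completion.partition(" ")
    return verdict.strip(".,:").lower() == "true", comment.strip()

# Illustrative decoded string (not real model output)
example = "<s>[INST] ... [/INST] True. Tenerife is one of the Canary Islands.</s>"
print(parse_feedback(example))  # (True, 'Tenerife is one of the Canary Islands.')
```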
|