nmarafo commited on
Commit
99fe6f7
·
verified ·
1 Parent(s): 278113f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +58 -3
README.md CHANGED
@@ -19,7 +19,7 @@ Determine if the student's answer is correct or not. It only returns True if the
19
  Add a brief comment explaining why the answer is correct or incorrect.\n\n
20
  Question: {question}\n
21
  Expected Answer: {best_answer}\n
22
- Student Answer: {student_answer}[/INST]</s>"
23
  ```
24
 
25
 
@@ -84,8 +84,63 @@ Student Answer: {student_answer}[/INST]</s>"
84
  Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
85
 
86
  ## How to Get Started with the Model
87
-
88
- Use the code below to get started with the model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  [More Information Needed]
91
 
 
19
  Add a brief comment explaining why the answer is correct or incorrect.\n\n
20
  Question: {question}\n
21
  Expected Answer: {best_answer}\n
22
+ Student Answer: {student_answer}[/INST]"
23
  ```
24
 
25
 
 
84
  Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
85
 
86
  ## How to Get Started with the Model
87
+ In Google Colab:
88
+ '''
89
+ !pip install -q -U transformers peft accelerate optimum
90
+ !pip install datasets==2.15.0
91
+ !pip install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu117/
92
+
93
+ from peft import AutoPeftModelForCausalLM
94
+ from rich import print
95
+ from transformers import GenerationConfig, AutoTokenizer
96
+
97
+ import torch
98
+
99
+ model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
100
+ adapter="nmarafo/Mistral-7B-Instruct-v0.2-TrueFalse-Feedback-GPTQ"
101
+
102
+ def generate_prompt(data_point):
103
+ system_message = "Analyze the question, the expected answer, and the student's response. Determine if the student's answer is conceptually correct in relation to the expected answer, regardless of the exact wording. An answer will be considered correct if it accurately identifies the key information requested in the question, even if expressed differently. Return True if the student's answer is correct or False otherwise. Add a brief comment explaining the rationale behind the answer being correct or incorrect."
104
+ question = data_point["question"][0]
105
+ best_answer = data_point["best_answer"][0]
106
+ student_answer = data_point["student_answer"][0]
107
+ prompt = f"{system_message}\n\nQuestion: {question}\nExpected Answer: {best_answer}\nStudent Answer: {student_answer}"
108
+
109
+ return prompt
110
+
111
+ tokenizer = AutoTokenizer.from_pretrained(
112
+ model_id,
113
+ trust_remote_code=True,
114
+ return_token_type_ids=False)
115
+ tokenizer.pad_token = tokenizer.eos_token
116
+
117
+ question="Name of Canary Island"
118
+ best_answer="Tenerife, Fuerteventura, Gran Canaria, Lanzarote, La Palma, La Gomera, El Hierro, La Graciosa"
119
+ student_answer="Tenerife"
120
+
121
+ prompt = generate_prompt([{"question":question, "best_answer":best_answer,"student_answer":student_answer}])
122
+ prompt_template=f'''<s>[INST] {prompt} [/INST]'''
123
+
124
+ input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
125
+ output = persisted_model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
126
+ print(tokenizer.decode(output[0]))
127
+
128
+ # To perform inference on the test dataset example load the model from the checkpoint
129
+ persisted_model = AutoPeftModelForCausalLM.from_pretrained(
130
+ adapter,
131
+ low_cpu_mem_usage=True,
132
+ return_dict=True,
133
+ torch_dtype=torch.float16,
134
+ device_map="cuda")
135
+ # Some gen config knobs
136
+ generation_config = GenerationConfig(
137
+ penalty_alpha=0.6,
138
+ do_sample = True,
139
+ top_k=5,
140
+ temperature=0.5,
141
+ repetition_penalty=1.2,
142
+ max_new_tokens=512
143
+ )
144
 
145
  [More Information Needed]
146