Update README.md
README.md CHANGED
````diff
@@ -98,41 +98,33 @@ from transformers import GenerationConfig, AutoTokenizer
 import torch
 
 model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
-adapter="nmarafo/Mistral-7B-Instruct-v0.2-TrueFalse-Feedback-GPTQ"
+adapter = "nmarafo/Mistral-7B-Instruct-v0.2-TrueFalse-Feedback-GPTQ"
 
-
-persisted_model = AutoPeftModelForCausalLM.from_pretrained(
-    adapter,
-    low_cpu_mem_usage=True,
-    return_dict=True,
-    torch_dtype=torch.float16,
-    device_map="cuda")
+tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, return_token_type_ids=False)
+tokenizer.pad_token = tokenizer.eos_token
 
-
-system_message = "Analyze the question, the expected answer, and the student's response. Determine if the student's answer is conceptually correct in relation to the expected answer, regardless of the exact wording. An answer will be considered correct if it accurately identifies the key information requested in the question, even if expressed differently. Return True if the student's answer is correct or False otherwise. Add a brief comment explaining the rationale behind the answer being correct or incorrect."
-question = data_point["question"][0]
-best_answer = data_point["best_answer"][0]
-student_answer = data_point["student_answer"][0]
-prompt = f"{system_message}\n\nQuestion: {question}\nExpected Answer: {best_answer}\nStudent Answer: {student_answer}"
+model = AutoPeftModelForCausalLM.from_pretrained(adapter, low_cpu_mem_usage=True, return_dict=True, torch_dtype=torch.float16, device_map="cuda")
 
-
+def predict(question, best_answer, student_answer):
+    system_message = "Analyze the question, the expected answer, and the student's response. Determine if the student's answer is conceptually correct in relation to the expected answer, regardless of the exact wording. Return True if the student's answer is correct or False otherwise. Add a brief comment explaining the rationale behind the answer being correct or incorrect."
+    prompt = f"{system_message}\n\nQuestion: {question}\nBest Answer: {best_answer}\nStudent Answer: {student_answer}"
+    prompt_template = f"<s>[INST]{prompt}[/INST]"
 
-tokenizer = AutoTokenizer.from_pretrained(
-    model_id,
-    trust_remote_code=True,
-    return_token_type_ids=False)
-tokenizer.pad_token = tokenizer.eos_token
+    encoding = tokenizer(prompt_template, return_tensors='pt', padding=True, truncation=True, max_length=512)
+    input_ids = encoding['input_ids'].cuda()
+    attention_mask = encoding['attention_mask'].cuda()
 
-
+    output = model.generate(input_ids, attention_mask=attention_mask,
+                            temperature=0.7, do_sample=True, top_p=0.95,
+                            top_k=40, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
+    response = tokenizer.decode(output[0], skip_special_tokens=True)
+    return response
+
+question = "Mention all the Canary Islands"
 best_answer="Tenerife, Fuerteventura, Gran Canaria, Lanzarote, La Palma, La Gomera, El Hierro, La Graciosa"
 student_answer="Tenerife"
 
-
-prompt_template=f'''<s>[INST] {prompt} [/INST]'''
-
-input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
-output = persisted_model.generate(inputs=input_ids, temperature=0.7, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
-print(tokenizer.decode(output[0]))
+print(predict(question, best_answer, student_answer))
 
 ```
 
````
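For reference, the updated snippet calls `AutoPeftModelForCausalLM`, whose import is not visible in this hunk; the hunk's context line only shows `from transformers import GenerationConfig, AutoTokenizer`, while `AutoPeftModelForCausalLM` is provided by the `peft` package. A minimal sketch of the imports the new code appears to assume (the GPTQ-runtime note is an assumption inferred from the quantized base model, not something this commit states):

```python
# Imports the updated snippet depends on.
import torch
from transformers import GenerationConfig, AutoTokenizer
# AutoPeftModelForCausalLM reads the base model name from the adapter's
# config, loads that base model, and attaches the LoRA weights on top.
from peft import AutoPeftModelForCausalLM
# Assumption: loading a GPTQ-quantized base model such as
# TheBloke/Mistral-7B-Instruct-v0.2-GPTQ also requires a GPTQ backend
# (e.g. the auto-gptq package) to be installed in the environment.
```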