doberst commited on
Commit
3156d80
·
1 Parent(s): 050fc78

Upload generation_test_hf_script.py

Browse files
Files changed (1) hide show
  1. generation_test_hf_script.py +87 -0
generation_test_hf_script.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+
6
+
7
+ def load_rag_benchmark_tester_ds():
8
+
9
+ # pull 200 question rag benchmark test dataset from LLMWare HuggingFace repo
10
+ from datasets import load_dataset
11
+
12
+ ds_name = "llmware/rag_instruct_benchmark_tester"
13
+
14
+ dataset = load_dataset(ds_name)
15
+
16
+ print("update: loading test dataset - ", dataset)
17
+
18
+ test_set = []
19
+ for i, samples in enumerate(dataset["train"]):
20
+ test_set.append(samples)
21
+
22
+ # to view test set samples
23
+ # print("rag benchmark dataset test samples: ", i, samples)
24
+
25
+ return test_set
26
+
27
+
28
+ def run_test(model_name, test_ds):
29
+
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+
32
+ print("update: model will be loaded on device - ", device)
33
+
34
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
35
+ model.to(device)
36
+
37
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
38
+
39
+ for i, entries in enumerate(test_ds):
40
+
41
+ # prepare prompt packaging used in fine-tuning process
42
+ new_prompt = "<human>: " + entries["context"] + "\n" + entries["query"] + "\n" + "<bot>:"
43
+
44
+ inputs = tokenizer(new_prompt, return_tensors="pt")
45
+ start_of_output = len(inputs.input_ids[0])
46
+
47
+ # temperature: set at 0.3 for consistency of output
48
+ # max_new_tokens: set at 100 - may prematurely stop a few of the summaries
49
+
50
+ outputs = model.generate(
51
+ inputs.input_ids.to(device),
52
+ eos_token_id=tokenizer.eos_token_id,
53
+ pad_token_id=tokenizer.eos_token_id,
54
+ do_sample=True,
55
+ temperature=0.3,
56
+ max_new_tokens=100,
57
+ )
58
+
59
+ output_only = tokenizer.decode(outputs[0][start_of_output:],skip_special_tokens=True)
60
+
61
+ # quick/optional post-processing clean-up of potential fine-tuning artifacts
62
+
63
+ eot = output_only.find("<|endoftext|>")
64
+ if eot > -1:
65
+ output_only = output_only[:eot]
66
+
67
+ bot = output_only.find("<bot>:")
68
+ if bot > -1:
69
+ output_only = output_only[bot+len("<bot>:"):]
70
+
71
+ # end - post-processing
72
+
73
+ print("\n")
74
+ print(i, "llm_response - ", output_only)
75
+ print(i, "gold_answer - ", entries["answer"])
76
+
77
+ return 0
78
+
79
+
80
+ if __name__ == "__main__":
81
+
82
+ test_ds = load_rag_benchmark_tester_ds()
83
+
84
+ model_name = "llmware/dragon-mistral-7b-v0"
85
+ output = run_test(model_name,test_ds)
86
+
87
+