sapphomoon committed on
Commit afb83aa · 1 Parent(s): ce3fe82

Upload truthful_qa.py

Files changed (1)
  1. truthful_qa.py +164 -0
truthful_qa.py ADDED
@@ -0,0 +1,164 @@
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """TruthfulQA dataset."""
+
+
+ import csv
+ import json
+
+ import datasets
+
+
+ _CITATION = """\
+ @misc{lin2021truthfulqa,
+     title={TruthfulQA: Measuring How Models Mimic Human Falsehoods},
+     author={Stephanie Lin and Jacob Hilton and Owain Evans},
+     year={2021},
+     eprint={2109.07958},
+     archivePrefix={arXiv},
+     primaryClass={cs.CL}
+ }
+ """
+
+ _DESCRIPTION = """\
+ TruthfulQA is a benchmark to measure whether a language model is truthful in
+ generating answers to questions. The benchmark comprises 817 questions that
+ span 38 categories, including health, law, finance and politics. Questions are
+ crafted so that some humans would answer falsely due to a false belief or
+ misconception. To perform well, models must avoid generating false answers
+ learned from imitating human texts.
+ """
+
+ _HOMEPAGE = "https://github.com/sylinrl/TruthfulQA"
+
+ _LICENSE = "Apache License 2.0"
+
+
+ class TruthfulQaConfig(datasets.BuilderConfig):
+     """BuilderConfig for TruthfulQA."""
+
+     def __init__(self, url, features, **kwargs):
+         """BuilderConfig for TruthfulQA.
+         Args:
+             url: *string*, the url to the configuration's data.
+             features: *list[string]*, list of features that'll appear in the feature dict.
+             **kwargs: keyword arguments forwarded to super.
+         """
+         super().__init__(version=datasets.Version("1.1.0"), **kwargs)
+         self.url = url
+         self.features = features
+
+
+ class TruthfulQa(datasets.GeneratorBasedBuilder):
+     """TruthfulQA is a benchmark to measure whether a language model is truthful in generating answers to questions."""
+
+     BUILDER_CONFIGS = [
+         TruthfulQaConfig(
+             name="generation",
+             url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv",
+             features=datasets.Features(
+                 {
+                     "type": datasets.Value("string"),
+                     "category": datasets.Value("string"),
+                     "question": datasets.Value("string"),
+                     "best_answer": datasets.Value("string"),
+                     "correct_answers": datasets.features.Sequence(datasets.Value("string")),
+                     "incorrect_answers": datasets.features.Sequence(datasets.Value("string")),
+                     "source": datasets.Value("string"),
+                 }
+             ),
+             description="The Generation TruthfulQA (main) task tests a model's ability to generate 1-2 sentence answers for a given question truthfully.",
+         ),
+         TruthfulQaConfig(
+             name="multiple_choice",
+             url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json",
+             features=datasets.Features(
+                 {
+                     "question": datasets.Value("string"),
+                     "mc1_targets": {
+                         "choices": datasets.features.Sequence(datasets.Value("string")),
+                         "labels": datasets.features.Sequence(datasets.Value("int32")),
+                     },
+                     "mc2_targets": {
+                         "choices": datasets.features.Sequence(datasets.Value("string")),
+                         "labels": datasets.features.Sequence(datasets.Value("int32")),
+                     },
+                 }
+             ),
+             description="The Multiple-Choice TruthfulQA task provides a multiple-choice option to test a model's ability to identify true statements.",
+         ),
+     ]
+
+     def _info(self):
+         return datasets.DatasetInfo(
+             description=_DESCRIPTION,
+             features=self.config.features,
+             homepage=_HOMEPAGE,
+             license=_LICENSE,
+             citation=_CITATION,
+         )
+
+     def _split_generators(self, dl_manager):
+         data_dir = dl_manager.download(self.config.url)
+         return [
+             datasets.SplitGenerator(
+                 name=datasets.Split.VALIDATION,
+                 gen_kwargs={
+                     "filepath": data_dir,
+                 },
+             ),
+         ]
+
+     def _split_csv_list(self, csv_list: str, delimiter: str = ";") -> list[str]:
+         """
+         Splits a csv list field, delimited by `delimiter` (';'), into a list
+         of strings.
+         """
+         csv_list = csv_list.strip().split(delimiter)
+         return [item.strip() for item in csv_list]
+
+     def _generate_examples(self, filepath):
+         if self.config.name == "multiple_choice":
+             # Multiple choice data is in a `JSON` file.
+             with open(filepath, encoding="utf-8") as f:
+                 contents = json.load(f)
+                 for key, row in enumerate(contents):
+                     yield key, {
+                         "question": row["question"],
+                         "mc1_targets": {
+                             "choices": list(row["mc1_targets"].keys()),
+                             "labels": list(row["mc1_targets"].values()),
+                         },
+                         "mc2_targets": {
+                             "choices": list(row["mc2_targets"].keys()),
+                             "labels": list(row["mc2_targets"].values()),
+                         },
+                     }
+         else:
+             # Generation data is in a `CSV` file.
+             with open(filepath, newline="", encoding="utf-8-sig") as f:
+                 contents = csv.DictReader(f)
+                 for key, row in enumerate(contents):
+                     # Ensure that references exist.
+                     if not row["Correct Answers"] or not row["Incorrect Answers"]:
+                         continue
+                     yield key, {
+                         "type": row["Type"],
+                         "category": row["Category"],
+                         "question": row["Question"],
+                         "best_answer": row["Best Answer"],
+                         "correct_answers": self._split_csv_list(row["Correct Answers"]),
+                         "incorrect_answers": self._split_csv_list(row["Incorrect Answers"]),
+                         "source": row["Source"],
+                     }
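
For reference, a minimal usage sketch (not part of the commit): it assumes the script above is saved locally as truthful_qa.py and that the installed `datasets` version still executes dataset loading scripts. The config names ("generation", "multiple_choice") and the single validation split come directly from the code above.

# Usage sketch. Assumptions: truthful_qa.py from this commit sits in the
# working directory, and the installed `datasets` release supports scripts.
from datasets import load_dataset

# Both configs expose only a "validation" split (see _split_generators).
generation = load_dataset("truthful_qa.py", "generation", split="validation")
mc = load_dataset("truthful_qa.py", "multiple_choice", split="validation")

print(generation[0]["question"])
print(generation[0]["correct_answers"])  # list of strings, split on ";" by _split_csv_list
print(mc[0]["mc1_targets"]["labels"])    # one 0/1 label per choice

Note that the generation config downloads the TruthfulQA.csv file and the multiple_choice config downloads mc_task.json, both pinned to a fixed upstream commit, so results are reproducible across runs.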