sapphomoon
committed on
Commit
·
afb83aa
1
Parent(s):
ce3fe82
Upload truthful_qa.py
Browse files- truthful_qa.py +164 -0
truthful_qa.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""TruthfulQA dataset."""
|
15 |
+
|
16 |
+
|
17 |
+
import csv
|
18 |
+
import json
|
19 |
+
|
20 |
+
import datasets
|
21 |
+
|
22 |
+
|
23 |
+
# BibTeX entry for the TruthfulQA paper; passed as `citation` to
# datasets.DatasetInfo in TruthfulQa._info().
_CITATION = """\
|
24 |
+
@misc{lin2021truthfulqa,
|
25 |
+
title={TruthfulQA: Measuring How Models Mimic Human Falsehoods},
|
26 |
+
author={Stephanie Lin and Jacob Hilton and Owain Evans},
|
27 |
+
year={2021},
|
28 |
+
eprint={2109.07958},
|
29 |
+
archivePrefix={arXiv},
|
30 |
+
primaryClass={cs.CL}
|
31 |
+
}
|
32 |
+
"""
|
33 |
+
|
34 |
+
# Short summary of the benchmark; passed as `description` to
# datasets.DatasetInfo in TruthfulQa._info().
_DESCRIPTION = """\
|
35 |
+
TruthfulQA is a benchmark to measure whether a language model is truthful in
|
36 |
+
generating answers to questions. The benchmark comprises 817 questions that
|
37 |
+
span 38 categories, including health, law, finance and politics. Questions are
|
38 |
+
crafted so that some humans would answer falsely due to a false belief or
|
39 |
+
misconception. To perform well, models must avoid generating false answers
|
40 |
+
learned from imitating human texts.
|
41 |
+
"""
|
42 |
+
|
43 |
+
# Upstream project page; passed as `homepage` to datasets.DatasetInfo.
_HOMEPAGE = "https://github.com/sylinrl/TruthfulQA"
|
44 |
+
|
45 |
+
# License name; passed as `license` to datasets.DatasetInfo.
_LICENSE = "Apache License 2.0"
|
46 |
+
|
47 |
+
|
48 |
+
class TruthfulQaConfig(datasets.BuilderConfig):
    """Builder configuration describing one TruthfulQA task variant."""

    def __init__(self, url, features, **kwargs):
        """Create a TruthfulQA configuration.

        Args:
            url: *string*, location of the raw data file backing this
                configuration.
            features: feature specification for the examples of this
                configuration (the builder configs in this file pass a
                `datasets.Features` object).
            **kwargs: keyword arguments forwarded to super.
        """
        # The version is fixed here rather than taken from the caller.
        dataset_version = datasets.Version("1.1.0")
        super().__init__(version=dataset_version, **kwargs)
        self.url = url
        self.features = features
|
61 |
+
|
62 |
+
|
63 |
+
class TruthfulQa(datasets.GeneratorBasedBuilder):
    """TruthfulQA is a benchmark to measure whether a language model is truthful in generating answers to questions."""

    # Two configurations: "generation" reads a CSV file, "multiple_choice"
    # reads a JSON file. Both URLs are pinned to the same upstream commit hash.
    BUILDER_CONFIGS = [
        TruthfulQaConfig(
            name="generation",
            url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/TruthfulQA.csv",
            features=datasets.Features(
                {
                    "type": datasets.Value("string"),
                    "category": datasets.Value("string"),
                    "question": datasets.Value("string"),
                    "best_answer": datasets.Value("string"),
                    "correct_answers": datasets.features.Sequence(datasets.Value("string")),
                    "incorrect_answers": datasets.features.Sequence(datasets.Value("string")),
                    "source": datasets.Value("string"),
                }
            ),
            description="The Generation TruthfulQA (main) task tests a model's ability to generate 1-2 sentence answers for a given question truthfully.",
        ),
        TruthfulQaConfig(
            name="multiple_choice",
            url="https://raw.githubusercontent.com/sylinrl/TruthfulQA/013686a06be7a7bde5bf8223943e106c7250123c/data/mc_task.json",
            features=datasets.Features(
                {
                    "question": datasets.Value("string"),
                    "mc1_targets": {
                        "choices": datasets.features.Sequence(datasets.Value("string")),
                        "labels": datasets.features.Sequence(datasets.Value("int32")),
                    },
                    "mc2_targets": {
                        "choices": datasets.features.Sequence(datasets.Value("string")),
                        "labels": datasets.features.Sequence(datasets.Value("int32")),
                    },
                }
            ),
            description="The Multiple-Choice TruthfulQA task provides a multiple-choice option to test a model's ability to identify true statements.",
        ),
    ]

    def _info(self):
        """Return the dataset metadata, using the selected config's features."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=self.config.features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download the config's data file and expose it as one validation split.

        Args:
            dl_manager: download manager; its `download` method is called with
                the selected config's URL.

        Returns:
            A one-element list holding the validation `SplitGenerator`, whose
            `filepath` keyword is forwarded to `_generate_examples`.
        """
        data_path = dl_manager.download(self.config.url)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={
                    "filepath": data_path,
                },
            ),
        ]

    def _split_csv_list(self, csv_list: str, delimiter: str = ";") -> list:
        """Split a delimited list field into a list of stripped strings.

        Args:
            csv_list: the raw field text, e.g. ``"a; b; c"``.
            delimiter: separator between items (default ``";"``).

        Returns:
            The items in their original order, each with surrounding
            whitespace removed. Empty items are kept, so a trailing delimiter
            produces a trailing ``""`` and an empty field produces ``[""]``.
        """
        return [item.strip() for item in csv_list.strip().split(delimiter)]

    def _generate_examples(self, filepath):
        """Yield ``(key, example)`` pairs read from the downloaded data file.

        Args:
            filepath: path of the local file; parsed as JSON for the
                "multiple_choice" config and as CSV for any other config.

        Yields:
            Tuples of the record's zero-based position in the file and the
            example dict matching the selected config's features.
        """
        if self.config.name == "multiple_choice":
            # Multiple choice data is in a `JSON` file.
            with open(filepath, encoding="utf-8") as f:
                contents = json.load(f)
                for key, row in enumerate(contents):
                    # Each targets mapping goes from answer text to its label;
                    # split it into parallel `choices` / `labels` lists.
                    yield key, {
                        "question": row["question"],
                        "mc1_targets": {
                            "choices": list(row["mc1_targets"].keys()),
                            "labels": list(row["mc1_targets"].values()),
                        },
                        "mc2_targets": {
                            "choices": list(row["mc2_targets"].keys()),
                            "labels": list(row["mc2_targets"].values()),
                        },
                    }
        else:
            # Generation data is in a `CSV` file. "utf-8-sig" drops a leading
            # byte-order mark so the first header name is read cleanly.
            with open(filepath, newline="", encoding="utf-8-sig") as f:
                contents = csv.DictReader(f)
                for key, row in enumerate(contents):
                    # Ensure that references exist. Skipped rows still consume
                    # an index, so the yielded keys may contain gaps.
                    if not row["Correct Answers"] or not row["Incorrect Answers"]:
                        continue
                    yield key, {
                        "type": row["Type"],
                        "category": row["Category"],
                        "question": row["Question"],
                        "best_answer": row["Best Answer"],
                        "correct_answers": self._split_csv_list(row["Correct Answers"]),
                        "incorrect_answers": self._split_csv_list(row["Incorrect Answers"]),
                        "source": row["Source"],
                    }
|