Corey Morris commited on
Commit
e03b231
·
1 Parent(s): abac22e

MC1 column had 8 rows with a value of 1. It didn't make sense given the next highest value was 0.47 . Assuming they were data errors, they were removed

Browse files
result_data_processor.py CHANGED
@@ -48,6 +48,19 @@ class ResultDataProcessor:
48
  df.index = (df.index.str.replace('mc\|0', 'mc2', regex=True))
49
  df = df.loc[['harness|truthfulqa:mc2']]
50
  return df[[model_name]]
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
  @staticmethod
@@ -119,6 +132,9 @@ class ResultDataProcessor:
119
  cols = cols[-1:] + cols[:-1]
120
  data = data[cols]
121
 
 
 
 
122
  return data
123
 
124
  def rank_data(self):
 
48
  df.index = (df.index.str.replace('mc\|0', 'mc2', regex=True))
49
  df = df.loc[['harness|truthfulqa:mc2']]
50
  return df[[model_name]]
51
+
52
+ # remove extreme outliers from column harness|truthfulqa:mc1
53
+ def _remove_mc1_outliers(self, df):
54
+ mc1 = df['harness|truthfulqa:mc1']
55
+ # Identify the outliers
56
+ # outliers_condition = mc1 > mc1.quantile(.95)
57
+ outliers_condition = mc1 == 1.0
58
+ # Print out the number of outliers
59
+ print('Number of outliers: ', outliers_condition.sum())
60
+ # Replace the outliers with NaN
61
+ df.loc[outliers_condition, 'harness|truthfulqa:mc1'] = np.nan
62
+ return df
63
+
64
 
65
 
66
  @staticmethod
 
132
  cols = cols[-1:] + cols[:-1]
133
  data = data[cols]
134
 
135
+ # remove extreme outliers from column harness|truthfulqa:mc1
136
+ data = self._remove_mc1_outliers(data)
137
+
138
  return data
139
 
140
  def rank_data(self):
test_data_processing.py CHANGED
@@ -34,6 +34,13 @@ class TestResultDataProcessor(unittest.TestCase):
34
  def test_truthfulqa_mc(self):
35
  data = self.processor.data
36
  self.assertNotIn('truthfulqa:mc', data.columns)
 
 
 
 
 
 
 
37
 
38
  if __name__ == '__main__':
39
  unittest.main()
 
34
  def test_truthfulqa_mc(self):
35
  data = self.processor.data
36
  self.assertNotIn('truthfulqa:mc', data.columns)
37
+
38
+ # check for extreme outliers in mc1 column
39
+ def test_mc1_outliers(self):
40
+ data = self.processor.data
41
+ mc1 = data['harness|truthfulqa:mc1']
42
+ self.assertLess(mc1.max(), 1.0)
43
+ self.assertGreater(mc1.min(), 0.0)
44
 
45
  if __name__ == '__main__':
46
  unittest.main()