Spaces:
Running
Running
File size: 4,546 Bytes
39b7b6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from typing import Optional
import numpy as np
import weave
class BaseAccuracyMetric(weave.Scorer):
"""
BaseAccuracyMetric is a class that extends the
[`weave.Scorer`](https://weave-docs.wandb.ai/guides/evaluation/scorers#class-based-scorers)
to provide a comprehensive evaluation of accuracy metrics for a given set of score rows.
This class is designed to process a list of score rows, each containing a
'correct' key that indicates whether a particular prediction was correct.
The `summarize` method calculates various statistical measures and metrics
based on this data, including:
- True and false counts: The number of true and false predictions.
- True and false fractions: The proportion of true and false predictions.
- Standard error: The standard error of the mean for the true predictions.
- Precision: The ratio of true positive predictions to the total number of
positive predictions.
- Recall: The ratio of true positive predictions to the total number of
actual positives.
- F1 Score: The harmonic mean of precision and recall, providing a balance
between the two metrics.
The `summarize` method returns a dictionary containing these metrics,
allowing for a detailed analysis of the model's performance.
Methods:
summarize(score_rows: list) -> Optional[dict]:
Processes the input score rows to compute and return a dictionary
of accuracy metrics.
"""
@weave.op()
def summarize(self, score_rows: list) -> Optional[dict]:
"""
Summarizes the accuracy metrics from a list of score rows.
This method processes a list of score rows, each containing a 'correct' key
that indicates whether a particular prediction was correct. It calculates
various statistical measures and metrics based on this data, including:
- True and false counts: The number of true and false predictions.
- True and false fractions: The proportion of true and false predictions.
- Standard error: The standard error of the mean for the true predictions.
- Precision: The ratio of true positive predictions to the total number of
positive predictions.
- Recall: The ratio of true positive predictions to the total number of
actual positives.
- F1 Score: The harmonic mean of precision and recall, providing a balance
between the two metrics.
The method returns a dictionary containing these metrics, allowing for a
detailed analysis of the model's performance.
Args:
score_rows (list): A list of dictionaries, each containing a 'correct'
key with a boolean value indicating the correctness of a prediction.
Returns:
Optional[dict]: A dictionary containing the calculated accuracy metrics,
or None if the input list is empty.
"""
valid_data = [
x.get("correct") for x in score_rows if x.get("correct") is not None
]
count_true = list(valid_data).count(True)
int_data = [int(x) for x in valid_data]
sample_mean = np.mean(int_data) if int_data else 0
sample_variance = np.var(int_data) if int_data else 0
sample_error = np.sqrt(sample_variance / len(int_data)) if int_data else 0
# Calculate precision, recall, and F1 score
true_positives = count_true
false_positives = len(valid_data) - count_true
false_negatives = len(score_rows) - len(valid_data)
precision = (
true_positives / (true_positives + false_positives)
if (true_positives + false_positives) > 0
else 0
)
recall = (
true_positives / (true_positives + false_negatives)
if (true_positives + false_negatives) > 0
else 0
)
f1_score = (
(2 * precision * recall) / (precision + recall)
if (precision + recall) > 0
else 0
)
return {
"correct": {
"true_count": count_true,
"false_count": len(score_rows) - count_true,
"true_fraction": float(sample_mean),
"false_fraction": 1.0 - float(sample_mean),
"stderr": float(sample_error),
"precision": precision,
"recall": recall,
"f1_score": f1_score,
}
}
|