YYama0 committed · Commit 8912001 (verified) · 1 Parent(s): 39426b5

Upload evaluate.py

Files changed (1)
  1. evaluate.py +89 -0
evaluate.py ADDED
@@ -0,0 +1,89 @@
+ import numpy as np
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ def evaluate_metrics(labels, probs, target_cols, threshold=0.5):
+     """
+     Evaluate metrics for each label
+
+     Parameters:
+     -----------
+     labels : numpy.ndarray
+         Ground truth labels (0 or 1)
+     probs : numpy.ndarray
+         Prediction probabilities
+     target_cols : list
+         List of target disease names
+     threshold : float
+         Threshold for converting probabilities to binary predictions
+
+     Returns:
+     --------
+     pandas.DataFrame
+         Metrics for each label
+     """
+     # Convert probabilities to binary predictions
+     preds = (probs >= threshold).astype(int)
+
+     # Dictionary to store metrics
+     metrics_dict = {
+         'Disease': target_cols,
+         'Positive_Samples': [],
+         'Accuracy': [],
+         'Precision': [],
+         'Recall': [],
+         'F1': [],
+         'AUC-ROC': [],
+         'AP': []  # Average Precision
+     }
+
+     # Calculate metrics for each label
+     for i in range(len(target_cols)):
+         metrics_dict['Positive_Samples'].append(np.sum(labels[:, i]))
+         metrics_dict['Accuracy'].append(accuracy_score(labels[:, i], preds[:, i]))
+         metrics_dict['Precision'].append(precision_score(labels[:, i], preds[:, i], zero_division=0))
+         metrics_dict['Recall'].append(recall_score(labels[:, i], preds[:, i], zero_division=0))
+         metrics_dict['F1'].append(f1_score(labels[:, i], preds[:, i], zero_division=0))
+         metrics_dict['AUC-ROC'].append(roc_auc_score(labels[:, i], probs[:, i]))
+         metrics_dict['AP'].append(average_precision_score(labels[:, i], probs[:, i]))
+
+     # Convert to DataFrame
+     metrics_df = pd.DataFrame(metrics_dict)
+
+     # Round numerical values to 3 decimal places
+     numeric_cols = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUC-ROC', 'AP']
+     metrics_df[numeric_cols] = metrics_df[numeric_cols].round(3)
+
+     return metrics_df
+
+ def plot_metrics_heatmap(metrics_df, metric_cols=['Precision', 'Recall', 'F1', 'AUC-ROC', 'AP']):
+     """
+     Plot a heatmap of evaluation metrics
+
+     Parameters:
+     -----------
+     metrics_df : pandas.DataFrame
+         Output from evaluate_metrics function
+     metric_cols : list
+         Metrics to display in the heatmap
+     """
+     plt.figure(figsize=(12, len(metrics_df) // 2))  # figure height scales with the number of diseases
+
+     # Prepare data for heatmap
+     heatmap_data = metrics_df[metric_cols].values
+
+     # Draw heatmap
+     sns.heatmap(heatmap_data,
+                 annot=True,
+                 fmt='.3f',
+                 cmap='YlOrRd',
+                 xticklabels=metric_cols,
+                 yticklabels=metrics_df['Disease'],
+                 vmin=0,
+                 vmax=1)
+
+     plt.title('Evaluation Metrics Heatmap')
+     plt.tight_layout()
+     plt.show()
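
For reference, a minimal usage sketch of the two functions above, assuming the script is importable as evaluate from the working directory; the random data, array shapes, and disease names are illustrative assumptions, not part of the uploaded file.

# Minimal usage sketch (illustrative only: random data and made-up label names
# stand in for real model outputs; "evaluate" refers to the uploaded evaluate.py).
import numpy as np
from evaluate import evaluate_metrics, plot_metrics_heatmap

rng = np.random.default_rng(seed=0)
target_cols = ['Atelectasis', 'Cardiomegaly', 'Pleural_Effusion']  # hypothetical disease names
labels = rng.integers(0, 2, size=(200, len(target_cols)))          # ground-truth 0/1 matrix
probs = rng.random(size=(200, len(target_cols)))                   # predicted probabilities in [0, 1)

metrics_df = evaluate_metrics(labels, probs, target_cols, threshold=0.5)
print(metrics_df)

plot_metrics_heatmap(metrics_df)  # per-disease heatmap of the selected metric columns

Note that roc_auc_score (and average_precision_score) raise an error when a label column contains only one class, so very rare findings may need to be excluded or handled before calling evaluate_metrics on a small evaluation set.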