import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def evaluate_metrics(labels, probs, target_cols, threshold=0.5):
    """
    Evaluate metrics for each label

    Parameters:
    -----------
    labels : numpy.ndarray
        Ground truth labels (0 or 1)
    probs : numpy.ndarray
        Prediction probabilities
    target_cols : list
        List of target disease names
    threshold : float
        Threshold for converting probabilities to binary predictions
    
    Returns:
    --------
    pandas.DataFrame
        Metrics for each label
    """
    # Convert probabilities to binary predictions
    preds = (probs >= threshold).astype(int)
    
    # Dictionary to store metrics
    metrics_dict = {
        'Disease': target_cols,
        'Positive_Samples': [],
        'Accuracy': [],
        'Precision': [],
        'Recall': [],
        'F1': [],
        'AUC-ROC': [],
        'AP': []  # Average Precision
    }
    
    # Calculate metrics for each label
    for i in range(len(target_cols)):
        metrics_dict['Positive_Samples'].append(np.sum(labels[:, i]))
        metrics_dict['Accuracy'].append(accuracy_score(labels[:, i], preds[:, i]))
        metrics_dict['Precision'].append(precision_score(labels[:, i], preds[:, i], zero_division=0))
        metrics_dict['Recall'].append(recall_score(labels[:, i], preds[:, i], zero_division=0))
        metrics_dict['F1'].append(f1_score(labels[:, i], preds[:, i], zero_division=0))
        # AUC-ROC and Average Precision are undefined when a label column
        # contains only one class, so record NaN instead of raising an error
        if len(np.unique(labels[:, i])) < 2:
            metrics_dict['AUC-ROC'].append(np.nan)
            metrics_dict['AP'].append(np.nan)
        else:
            metrics_dict['AUC-ROC'].append(roc_auc_score(labels[:, i], probs[:, i]))
            metrics_dict['AP'].append(average_precision_score(labels[:, i], probs[:, i]))
    
    # Convert to DataFrame
    metrics_df = pd.DataFrame(metrics_dict)
    
    # Round numerical values to 3 decimal places
    numeric_cols = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUC-ROC', 'AP']
    metrics_df[numeric_cols] = metrics_df[numeric_cols].round(3)
    
    return metrics_df

def plot_metrics_heatmap(metrics_df, metric_cols=['Precision', 'Recall', 'F1', 'AUC-ROC', 'AP']):
    """
    Plot a heatmap of evaluation metrics

    Parameters:
    -----------
    metrics_df : pandas.DataFrame
        Output from evaluate_metrics function
    metric_cols : list
        Metrics to display in the heatmap
    """
    # Scale the figure height with the number of diseases in the metrics table
    plt.figure(figsize=(12, max(4, len(metrics_df) // 2)))
    
    # Prepare data for heatmap
    heatmap_data = metrics_df[metric_cols].values
    
    # Draw heatmap
    sns.heatmap(heatmap_data, 
                annot=True, 
                fmt='.3f',
                cmap='YlOrRd',
                xticklabels=metric_cols,
                yticklabels=metrics_df['Disease'],
                vmin=0, 
                vmax=1)
    
    plt.title('Evaluation Metrics Heatmap')
    plt.tight_layout()
    plt.show()
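

# Minimal usage sketch (not part of the original module): exercises both
# functions on synthetic data. The disease names, sample count, and random
# probabilities below are illustrative assumptions, not real dataset values.
if __name__ == "__main__":
    rng = np.random.default_rng(seed=42)
    example_cols = ["Cardiomegaly", "Edema", "Pneumonia"]  # hypothetical targets
    n_samples = 200

    # Synthetic ground-truth labels (0/1) and prediction probabilities
    example_labels = rng.integers(0, 2, size=(n_samples, len(example_cols)))
    example_probs = rng.random(size=(n_samples, len(example_cols)))

    example_metrics = evaluate_metrics(example_labels, example_probs, example_cols)
    print(example_metrics)

    plot_metrics_heatmap(example_metrics)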