import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)


def evaluate_metrics(labels, probs, target_cols, threshold=0.5):
    """
    Evaluate per-label classification metrics.

    Parameters:
    -----------
    labels : array-like
        Ground truth labels (0 or 1); indexed as ``labels[:, i]`` for the
        i-th entry of ``target_cols``.
    probs : array-like
        Prediction probabilities, indexed the same way as ``labels``.
    target_cols : list
        List of target disease names (one per evaluated column).
    threshold : float
        Threshold for converting probabilities to binary predictions
        (``probs >= threshold`` is predicted positive).

    Returns:
    --------
    pandas.DataFrame
        One row per label with columns 'Disease', 'Positive_Samples',
        'Accuracy', 'Precision', 'Recall', 'F1', 'AUC-ROC' and 'AP'.
        Score columns are rounded to 3 decimal places. 'AUC-ROC' is NaN
        for a label whose ground truth contains a single class.
    """
    # Accept any array-like (lists, DataFrames, ...) while keeping the
    # 2-D column indexing used below.
    labels = np.asarray(labels)
    probs = np.asarray(probs)

    # Convert probabilities to binary predictions
    preds = (probs >= threshold).astype(int)

    # Dictionary to store metrics
    metrics_dict = {
        'Disease': list(target_cols),
        'Positive_Samples': [],
        'Accuracy': [],
        'Precision': [],
        'Recall': [],
        'F1': [],
        'AUC-ROC': [],
        'AP': [],  # Average Precision
    }

    # Calculate metrics for each label
    for col in range(len(target_cols)):
        y_true = labels[:, col]
        y_pred = preds[:, col]
        y_prob = probs[:, col]

        # ROC AUC is undefined when only one class is present in y_true;
        # record NaN for that label instead of letting a single degenerate
        # column abort the evaluation of every other label.
        if np.unique(y_true).size < 2:
            auc = np.nan
        else:
            auc = roc_auc_score(y_true, y_prob)

        metrics_dict['Positive_Samples'].append(np.sum(y_true))
        metrics_dict['Accuracy'].append(accuracy_score(y_true, y_pred))
        metrics_dict['Precision'].append(
            precision_score(y_true, y_pred, zero_division=0))
        metrics_dict['Recall'].append(
            recall_score(y_true, y_pred, zero_division=0))
        metrics_dict['F1'].append(f1_score(y_true, y_pred, zero_division=0))
        metrics_dict['AUC-ROC'].append(auc)
        metrics_dict['AP'].append(average_precision_score(y_true, y_prob))

    # Convert to DataFrame
    metrics_df = pd.DataFrame(metrics_dict)

    # Round numerical values to 3 decimal places
    numeric_cols = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUC-ROC', 'AP']
    metrics_df[numeric_cols] = metrics_df[numeric_cols].round(3)

    return metrics_df


def plot_metrics_heatmap(
        metrics_df,
        metric_cols=('Precision', 'Recall', 'F1', 'AUC-ROC', 'AP')):
    """
    Plot a heatmap of evaluation metrics.

    Parameters:
    -----------
    metrics_df : pandas.DataFrame
        Output from evaluate_metrics function; must contain a 'Disease'
        column and every column named in ``metric_cols``.
    metric_cols : sequence of str
        Metrics to display in the heatmap (one heatmap column each).
    """
    # pandas treats a tuple as a single key, so always index with a list.
    metric_cols = list(metric_cols)

    # Size the figure from the frame itself (one row per disease); the
    # height is clamped to 1 so frames with fewer than two rows still get a
    # valid, non-zero figure size.
    fig_height = max(1, len(metrics_df) // 2)
    plt.figure(figsize=(12, fig_height))

    # Prepare data for heatmap
    heatmap_data = metrics_df[metric_cols].values

    # Draw heatmap; the colour scale is fixed to [0, 1]
    sns.heatmap(heatmap_data,
                annot=True,
                fmt='.3f',
                cmap='YlOrRd',
                xticklabels=metric_cols,
                yticklabels=metrics_df['Disease'],
                vmin=0,
                vmax=1)

    plt.title('Evaluation Metrics Heatmap')
    plt.tight_layout()
    plt.show()