Spaces:
Running
Running
"""Helper functions for classification tasks.""" | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import numpy as np | |
def plot_metric_curve(
    xvalues, yvalues, thresholds, title=None,
    figsize=(8, 7), show_thresholds=True, show_legend=True,
    ylabel='X', xlabel='Y', ax=None, text_delta=0.01,
    label="Metric Curve", color="royalblue", show=False,
    fill=None,
):
    """Plot a metric curve, e.g., PR curve or ROC curve.

    Args:
        xvalues: x-coordinates of the curve points.
        yvalues: y-coordinates of the curve points.
        thresholds: values annotated next to the curve points. They are
            paired with the points in order; pairing stops at the shortest
            of the three sequences, so surplus points stay unannotated.
        title: axes title.
        figsize: size of the figure created when ``ax`` is None.
        show_thresholds: whether to annotate points with their threshold
            (rounded to two decimals).
        show_legend: whether to draw the legend.
        ylabel: y-axis label. Note that the default is ``'X'``.
        xlabel: x-axis label. Note that the default is ``'Y'``.
        ax: axes to draw on; a new figure and axes are created when None.
        text_delta: offset added to both coordinates of each threshold
            annotation so the text does not sit on top of its marker.
        label: legend label of the curve.
        color: color used for the curve, the fill and all text.
        show: whether to call ``plt.show()`` at the end.
        fill: if not None, the area between the curve and y=0 is shaded
            and this value is written as text at (0.4, 0.5).

    Returns:
        None. The plot is drawn on ``ax`` (or on the newly created axes).
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.grid(alpha=0.3)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.plot(xvalues, yvalues, marker='o', label=label, color=color)
    ax.set_xlim(-0.08, 1.08)
    ax.set_ylim(-0.08, 1.08)

    if fill is not None:
        # Remember the tick positions so they can be re-applied after the
        # fill has been added.
        yticks = ax.get_yticks()
        # Shade the area under the curve. `fill_between` expects a numeric
        # baseline (y2); 0 is its documented default, passed explicitly.
        ax.fill_between(xvalues, yvalues, 0, alpha=0.08, color=color)
        # The annotation is placed at a fixed position, not computed from
        # the curve.
        ax.text(0.4, 0.5, fill, color=color)
        # Pin ticks and their labels (one decimal), then restore the limits.
        ax.set_yticks(yticks)
        ax.set_yticklabels([f"{y:.1f}" for y in yticks])
        ax.set_ylim(-0.08, 1.08)

    if show_thresholds:
        for x, y, t in zip(xvalues, yvalues, thresholds):
            ax.text(
                x + text_delta, y + text_delta, np.round(t, 2),
                color=color, alpha=0.5,
            )

    if show_legend:
        ax.legend()

    if show:
        plt.show()