Spaces:
Running
Running
File size: 1,448 Bytes
eafbf97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
"""Helper functions for classification tasks."""
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def plot_metric_curve(
    xvalues, yvalues, thresholds, title=None,
    figsize=(8, 7), show_thresholds=True, show_legend=True,
    ylabel='X', xlabel='Y', ax=None, text_delta=0.01,
    label="Metric Curve", color="royalblue", show=False,
    fill=None, fill_text_xy=(0.4, 0.5),
):
    """Plot a metric curve, e.g., PR curve or ROC curve.

    The curve is drawn as a line with a circular marker at every point.
    Both axes are fixed to the range [-0.08, 1.08], so the values are
    presumably expected to lie in [0, 1].

    Parameters
    ----------
    xvalues, yvalues : sequence of float
        Coordinates of the curve's points.
    thresholds : sequence of float
        Threshold associated with each point. It is only read when
        ``show_thresholds`` is true.
    title : str, optional
        Axes title.
    figsize : tuple of float, default (8, 7)
        Size of the newly created figure. Only used when ``ax`` is None.
    show_thresholds : bool, default True
        Annotate each point with its threshold, rounded to 2 decimals.
    show_legend : bool, default True
        Draw the axes legend.
    ylabel, xlabel : str, default 'X' and 'Y'
        Axis labels. NOTE(review): these defaults look swapped. They are
        kept unchanged for backward compatibility, so pass both explicitly.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When None, a new figure and axes are created.
    text_delta : float, default 0.01
        Offset, in data units, of each threshold annotation from its point.
    label : str, default "Metric Curve"
        Legend label of the curve.
    color : str, default "royalblue"
        Color of the curve, the shaded area and all annotations.
    show : bool, default False
        Call ``plt.show()`` once everything is drawn.
    fill : str, optional
        When given, shade the area between the curve and y=0, and write
        this text inside the plot.
    fill_text_xy : tuple of float, default (0.4, 0.5)
        Data coordinates at which the ``fill`` text is written.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)
    ax.grid(alpha=0.3)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    ax.plot(xvalues, yvalues, marker='o', label=label, color=color)
    # set_ylim also turns y-autoscaling off, so later artists (e.g. the
    # filled area) do not move these limits.
    ax.set_xlim(-0.08, 1.08)
    ax.set_ylim(-0.08, 1.08)
    if fill is not None:
        # Pass the baseline y2=0 as a number. Passing a string here (e.g. "")
        # would push the y-axis into matplotlib's categorical units and
        # replace its tick locator/formatter.
        ax.fill_between(xvalues, yvalues, 0, alpha=0.08, color=color)
        # TODO: compute a point guaranteed to lie inside the curve instead of
        # relying on fixed (caller-overridable) coordinates.
        ax.text(fill_text_xy[0], fill_text_xy[1], fill, color=color)
    # Show thresholds
    if show_thresholds:
        # zip() stops at the shortest input, so points without a matching
        # threshold are left unannotated.
        for x, y, t in zip(xvalues, yvalues, thresholds):
            ax.text(
                x + text_delta, y + text_delta, np.round(t, 2),
                color=color, alpha=0.5,
            )
    if show_legend:
        ax.legend()
    if show:
        plt.show()
|