SoundOfWater / shared /utils /classification.py
bpiyush's picture
Upload folder using huggingface_hub
eafbf97 verified
"""Helper functions for classification tasks."""
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def plot_metric_curve(
    xvalues, yvalues, thresholds, title=None,
    figsize=(8, 7), show_thresholds=True, show_legend=True,
    ylabel='X', xlabel='Y', ax=None, text_delta=0.01,
    label="Metric Curve", color="royalblue", show=False,
    fill=None,
):
    """Plot a metric curve, e.g., PR curve or ROC curve.

    Args:
        xvalues: x-coordinates of the curve points.
        yvalues: y-coordinates of the curve points.
        thresholds: values written next to the curve points; they are paired
            with the points via ``zip``, so extra points or thresholds are
            silently ignored. ``None`` disables the annotations.
        title: title passed to ``ax.set_title``.
        figsize: size of the figure created when ``ax`` is ``None``.
        show_thresholds: whether to annotate points with ``thresholds``.
        show_legend: whether to call ``ax.legend()``.
        ylabel: label of the y-axis.
        xlabel: label of the x-axis.
            NOTE: the defaults label the y-axis 'X' and the x-axis 'Y'; they
            are kept unchanged for backward compatibility, so callers should
            pass both labels explicitly.
        ax: axes to draw on; a new figure/axes pair is created when ``None``.
        text_delta: offset added to both coordinates of a point to position
            its threshold annotation.
        label: legend label of the curve.
        color: color used for the curve, the fill and all text.
        show: whether to call ``plt.show()`` at the end.
        fill: optional text. When given, the area between the curve and
            y = 0 is shaded and the text is drawn at (0.4, 0.5).
    """
    if ax is None:
        # Only the axes are needed; the figure handle is discarded.
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.grid(alpha=0.3)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    ax.plot(xvalues, yvalues, marker='o', label=label, color=color)

    ax.set_xlim(-0.08, 1.08)
    ax.set_ylim(-0.08, 1.08)

    if fill is not None:
        yticks = ax.get_yticks()
        # Shade down to an explicit numeric baseline. The previous empty-string
        # baseline relied on matplotlib's categorical unit conversion.
        ax.fill_between(xvalues, yvalues, 0, alpha=0.08, color=color)
        # Fixed position that is expected to lie under the curve.
        ax.text(0.4, 0.5, fill, color=color)
        # Pin the ticks with one-decimal labels, then re-apply the limits
        # because setting ticks may change the view range.
        ax.set_yticks(yticks)
        ax.set_yticklabels([f"{y:.1f}" for y in yticks])
        ax.set_ylim(-0.08, 1.08)

    # Show thresholds next to their points.
    if show_thresholds and thresholds is not None:
        for x, y, t in zip(xvalues, yvalues, thresholds):
            ax.text(
                x + text_delta, y + text_delta, np.round(t, 2),
                color=color, alpha=0.5,
            )

    if show_legend:
        ax.legend()

    if show:
        plt.show()