# climateGAN / figures / human_evaluation.py
# From cc-ai/climateGAN, initial commit 448ebbd (vict0rsch)
"""
This script plots the result of the human evaluation on Amazon Mechanical Turk, where
human participants chose between an image from ClimateGAN or from a different method.
"""
print("Imports...", end="")
from argparse import ArgumentParser
import os
import yaml
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt
# -----------------------
# ----- Constants -----
# -----------------------
# Maps the values of the CSV's `comparable` column to the display names used
# as Y-tick labels. Its key order matches the order of `palette_comparables`.
comparables_dict = {
    "munit_flooded": "MUNIT",
    "cyclegan": "CycleGAN",
    "instagan": "InstaGAN",
    "instagan_copypaste": "Mask-InstaGAN",
    "painted_ground": "Painted ground",
}

# Colors, all picked from seaborn's "colorblind" palette (fetched once; it was
# previously computed twice with the same result)
palette_colorblind = sns.color_palette("colorblind")
color_climategan = palette_colorblind[9]
color_munit = palette_colorblind[1]
color_cyclegan = palette_colorblind[2]
color_instagan = palette_colorblind[3]
color_maskinstagan = palette_colorblind[6]
color_paintedground = palette_colorblind[8]

# One color per comparable, in the same order as the keys of comparables_dict
palette_comparables = [
    color_munit,
    color_cyclegan,
    color_instagan,
    color_maskinstagan,
    color_paintedground,
]
# Lighter variant of each comparable color: the middle step of a 3-color
# light-to-color blend produced by sns.light_palette
palette_comparables_light = [
    sns.light_palette(color, n_colors=3)[1] for color in palette_comparables
]
def parsed_args(argv=None):
    """
    Parse and return command-line args

    Args:
        argv (list[str], optional): arguments to parse. Defaults to None, in
            which case argparse reads sys.argv[1:] (the original behaviour).

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="amt_omni-vs-other.csv",
        type=str,
        help="CSV containing the results of the human evaluation, pre-processed",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # Integer literal on purpose: argparse does not apply `type` to
        # non-string defaults, so the former `1e6` stayed the float 1000000.0
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # -----------------------------
    # ----- Parse arguments -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir: --output_dir if given, otherwise the SLURM job's
    # SLURM_TMPDIR, otherwise the current directory (this used to be a bare
    # KeyError when running outside SLURM)
    if args.output_dir is None:
        output_dir = Path(os.environ.get("SLURM_TMPDIR", "."))
    else:
        output_dir = Path(args.output_dir)
    # exist_ok=True replaces the racy exists() check followed by
    # mkdir(exist_ok=False)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args, so the figure can be regenerated with the same settings
    output_yml = output_dir / "args_human_evaluation.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Read CSV. Columns used below: "comparable", "climategan", "is_valid"
    df = pd.read_csv(args.input_csv)

    # Sort Y labels by decreasing total of the "climategan" column
    comparables = df.comparable.unique()
    is_climategan_sum = [
        df.loc[df.comparable == c, "climategan"].sum() for c in comparables
    ]
    # NOTE(review): this sorts by a count over all rows, while the bars below
    # show a mean over rows where is_valid is true; the two orders can differ
    # when comparables have different numbers of rows — confirm intent
    comparables = comparables[np.argsort(is_climategan_sum)[::-1]]

    # Plot setup
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # These two were fused into the single bogus family name
                # "serifBitstream Vera Serif" by a missing comma
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )
    fontsize = "medium"

    # Light color of each comparable, looked up by name. A list palette is
    # applied positionally to `order`, and `comparables` is sorted by votes,
    # so passing palette_comparables_light directly gave each method the
    # color of whichever slot it was sorted into.
    light_color_of = dict(zip(comparables_dict, palette_comparables_light))
    palette_total = [light_color_of[c] for c in comparables]

    # Initialize the matplotlib figure
    fig, ax = plt.subplots(figsize=(10.5, 3), dpi=args.dpi)

    # Plot the total (right): mean of "is_valid" over valid rows, presumably
    # 1.0 for a boolean column, i.e. a full-length background bar
    sns.barplot(
        data=df.loc[df.is_valid],
        x="is_valid",
        y="comparable",
        order=comparables,
        orient="h",
        label="comparable",
        palette=palette_total,
        ci=None,
    )
    # Plot the left: mean of "climategan" over valid rows (the ClimateGAN
    # choice rate, assuming a 0/1 column) with a 99% bootstrap CI
    sns.barplot(
        data=df.loc[df.is_valid],
        x="climategan",
        y="comparable",
        order=comparables,
        orient="h",
        label="climategan",
        color=color_climategan,
        ci=99,
        n_boot=args.n_bs,
        seed=args.bs_seed,
        errcolor="black",
        errwidth=1.5,
        capsize=0.1,
    )

    # Draw a dotted vertical line at 0.5 (equal preference between the two
    # methods), sampled from ylim[1] + 0.1 up to ylim[0]
    y = np.arange(ax.get_ylim()[1] + 0.1, ax.get_ylim()[0], 0.1)
    x = 0.5 * np.ones(y.shape[0])
    ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")

    # Change Y-Tick labels. Bars follow `order=comparables`, so the ticks are
    # the CSV keys in that order; map them straight to display names rather
    # than reading tick-label text, which may be empty before the first draw
    yticklabels = [comparables_dict[c] for c in comparables]
    yticklabels_text = ax.set_yticklabels(
        yticklabels, fontsize=fontsize, horizontalalignment="right", x=0.96
    )
    for ytl in yticklabels_text:
        ax.add_artist(ytl)

    # Remove Y-label
    ax.set_ylabel(ylabel="")

    # Change X-Tick labels: ticks every 0.1 from 0.0 to 1.0
    xlim = [0.0, 1.1]
    xticks = np.arange(xlim[0], xlim[1], 0.1)
    ax.set(xticks=xticks)
    plt.setp(ax.get_xticklabels(), fontsize=fontsize)

    # Set X-label
    ax.set_xlabel(None)

    # Change spines
    sns.despine(left=True, bottom=True)

    # Save figure
    output_fig = output_dir / "human_evaluation_rate_climategan.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")