Spaces:
Runtime error
Runtime error
import gradio as gr | |
import json | |
import numpy as np | |
import pandas as pd | |
from datasets import load_from_disk | |
from itertools import chain | |
import operator | |
pd.options.plotting.backend = "plotly" | |
TITLE = "Identity Biases in Diffusion Models: Professions"

# Markdown introduction rendered at the top of the app.
_INTRO = """
# Identity Biases in Diffusion Models: Professions
Explore profession-level social biases in the data from [DiffusionBiasExplorer](https://hf.co/spaces/tti-bias/diffusion-bias-explorer)!
This demo leverages the gender and ethnicity representation clusters described in the [companion app](https://hf.co/spaces/tti-bias/diffusion-face-clustering)
to analyze social trends in machine-generated visual representations of professions.
The **Professions Overview** tab lets you compare the distribution over
[identity clusters](https://hf.co/spaces/tti-bias/diffusion-face-clustering "Identity clusters identify visual features in the systems' output space correlated with variation of gender and ethnicity in input prompts.")
across professions for Stable Diffusion and Dall-E 2 systems (or aggregated for `All Models`).
The **Profession Focus** tab provides more details for each of the individual professions, including direct system comparisons and examples of profession images for each cluster.
This work was done in the scope of the [Stable Bias Project](https://hf.co/spaces/tti-bias/stable-bias).
"""

# Draft guidance text. The module-level name `_` is never read in this file,
# so this string is currently not displayed anywhere.
_ = """
For example, you can use this tool to investigate:
- How does each model's representation of professions correlate with the gender ratios reported by the [U.S. Bureau of Labor
Statistics](https://www.bls.gov/cps/cpsaat11.htm "The reported percentage of women in each profession in the US is indicated in the `Labor Women` column in the Professions Overview tab.")?
Are social trends reflected, or are they exaggerated?
- Which professions have the starkest differences in how different models represent them?
"""
def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Image dataset plus a pandas view of it (get_image filters the pandas view to
# find row positions, then selects those rows from the dataset).
professions_dset = load_from_disk("professions")
professions_df = professions_dset.to_pandas()

# Cluster statistics for each supported clustering granularity, keyed by the
# number of clusters (int); the functions below index it as
# [num_clusters][model key][profession].
clusters_dicts = {
    num_cl: _load_json(f"clusters/professions_to_clusters_{num_cl}.json")
    for num_cl in [12, 24, 48]
}
cluster_summaries_by_size = _load_json("clusters/cluster_summaries_by_size.json")

prompts = pd.read_csv("promptsadjectives.csv")
# Dropdown options: the aggregate entry first, then the CSV's occupations sorted.
professions = ["all professions"] + sorted(prompts["Occupation-Noun"].tolist())
# Model keys used in the cluster data -> names displayed in the UI.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}
# Reverse lookup (UI name -> data key), derived so the two mappings cannot drift.
df_models = {display_name: key for key, display_name in models.items()}
def describe_cluster(num_clusters, block="label", cl_dict=None):
    """Return an HTML summary of how labels are distributed within a cluster.

    NOTE(review): this helper is not called anywhere in this file. Also, the
    functions below index ``clusters_dicts[num_clusters]`` as nested
    ``[model][profession]`` dicts, not as a flat label -> count mapping, so the
    default lookup presumably does not match the data loaded here. Pass
    ``cl_dict`` explicitly to describe an arbitrary label -> count mapping.

    Args:
        num_clusters: Key into ``clusters_dicts``; only used when ``cl_dict``
            is None.
        block: Noun for what the labels are (e.g. "label", "gender"), used in
            the headline sentence.
        cl_dict: Optional mapping of label -> count/weight to describe.

    Returns:
        An HTML string naming the most represented label and its rounded
        percentage share, followed by every other label with its share. Labels
        are rendered with ``str()``. An empty mapping yields a short
        "no data" message instead of raising.
    """
    if cl_dict is None:
        cl_dict = clusters_dicts[num_clusters]
    if not cl_dict:
        return "<span>No %s data available for this cluster.</span>" % (block,)

    # Largest count first. Sorting ascending and then reversing means tied
    # labels appear in reverse insertion order.
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
    labels_values.reverse()
    total = float(sum(cl_dict.values()))
    # Whole-number percentage shares; an all-zero mapping reports 0% instead
    # of dividing by zero.
    lv_prcnt = [
        (label, round(value * 100 / total, 0) if total else 0.0)
        for label, value in labels_values
    ]

    top_label, top_share = lv_prcnt[0]
    headline = (
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (block, top_label, top_share)
    )
    followers = "".join(
        "<BR/><b>%s:</b> %d%%" % (label, share) for label, share in lv_prcnt[1:]
    )
    return headline + "<p>This is followed by: " + followers + "</p>"
def make_profession_plot(num_clusters, prof_name):
    """Build the per-model cluster bar chart and related widgets for one profession.

    Clusters are ranked by their proportion in the aggregated "All" model
    data, and only clusters whose aggregated proportion is positive are kept.

    Args:
        num_clusters: Clustering granularity, a key of ``clusters_dicts``.
        prof_name: Profession name as used in the cluster data.

    Returns:
        A 3-tuple:
        - the grouped bar chart from ``DataFrame.plot`` (rows are
          "Cluster <id>", columns are model display names);
        - a ``gr.update`` for the cluster dropdown, whose choices are the kept
          cluster ids and whose value is the top-ranked one;
        - a ``gr.update`` carrying a text summary of the kept clusters.
    """
    by_model = clusters_dicts[num_clusters]
    aggregate = by_model["All"][prof_name]["cluster_proportions"]
    ranked = sorted(aggregate.items(), key=lambda item: item[1], reverse=True)
    cluster_ids = [cl_id for cl_id, share in ranked if share > 0]

    # One column per model (display name), one row per kept cluster in rank order.
    per_model = {}
    for model_key, model_name in models.items():
        proportions = by_model[model_key][prof_name]["cluster_proportions"]
        per_model[model_name] = {
            f"Cluster {cl_id}": proportions[cl_id] for cl_id in cluster_ids
        }
    frame = pd.DataFrame.from_dict(per_model)
    prof_plot = frame.plot(kind="bar", barmode="group")

    summaries = cluster_summaries_by_size[str(num_clusters)]
    pieces = [f"Profession '{prof_name}':\n"]
    for cl_id in cluster_ids:
        summary = (
            summaries[int(cl_id)]
            .replace(' gender terms', '')
            .replace('; ethnicity terms:', ',')
        )
        pieces.append(f"- {summary} \n")
    cl_summary_text = "".join(pieces)

    return (
        prof_plot,
        gr.update(choices=cluster_ids, value=cluster_ids[0]),
        gr.update(value=cl_summary_text),
    )
def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
    """Render an HTML table of cluster proportions for the selected professions.

    Args:
        num_clusters: Clustering granularity, a key of ``clusters_dicts``.
        prof_names: Profession names to include, one table row each. ``None``
            is treated as an empty selection.
        mod_name: Model display name, a key of ``df_models``.
        max_cols: Maximum number of cluster columns. Clusters are ranked by
            their summed proportion over the selected professions, and clusters
            whose sum is 0 are omitted.

    Returns:
        A 3-tuple:
        - the ids of the clusters shown as columns, in ranked order;
        - the styled table as an HTML string;
        - a ``gr.update`` with a text summary of those same clusters.
    """
    model_data = clusters_dicts[num_clusters][df_models[mod_name]]
    # A cleared multiselect may arrive as None; treat it as "nothing selected".
    professions_list_clusters = [
        (prof_name, model_data[prof_name]["cluster_proportions"])
        for prof_name in (prof_names or [])
    ]

    # Rank every cluster id by its summed proportion over the selection, keep
    # the top max_cols, and drop clusters that never occur. The table columns,
    # the returned ids and the summary text then describe the same clusters.
    ranked = sorted(
        (
            (
                k,
                sum(
                    prof_clusters[str(k)]
                    for _, prof_clusters in professions_list_clusters
                ),
            )
            for k in range(num_clusters)
        ),
        key=lambda x: x[1],
        reverse=True,
    )
    totals = [(k, v) for k, v in ranked[:max_cols] if v > 0]

    prof_list_pre_pandas = []
    for prof_name, prof_clusters in professions_list_clusters:
        prof_data = model_data[prof_name]
        row = {
            "Profession": prof_name,
            "Entropy": prof_data["entropy"],
            "Labor Women": prof_data["labor_fm"][0],
            # Blank spacer column between the summary stats and the clusters.
            "": "",
        }
        for k, _ in totals:
            row[f"Cluster {k}"] = prof_clusters[str(k)]
        prof_list_pre_pandas.append(row)
    clusters_df = pd.DataFrame.from_dict(prof_list_pre_pandas)

    summaries = cluster_summaries_by_size[str(num_clusters)]
    cl_summary_text = "".join(
        f"- {summaries[cl_id].replace(' gender terms', '').replace('; ethnicity terms:', ',')} \n"
        for cl_id, _ in totals
    )

    table_html = (
        clusters_df.style.background_gradient(
            axis=None, vmin=0, vmax=100, cmap="YlGnBu"
        )
        .format(precision=1)
        .to_html()
    )
    return [k for k, _ in totals], table_html, gr.update(value=cl_summary_text)
def get_image(model, fname, score):
    """Fetch one generated image and build its gallery caption.

    Args:
        model: Model key (a key of ``models``) the image was generated with.
        fname: The image's ``image_path`` value in the dataset.
        score: Numeric score, shown in the caption with two decimals.

    Returns:
        An ``(image, caption)`` pair. The image is the first dataset row whose
        ``image_path`` and ``model`` both match. The caption joins the
        underscore-separated parts of the first path component, from index 4
        onwards, with spaces, followed by the score and the model's display
        name, all separated by " | ".
    """
    matches = (professions_df["image_path"] == fname) & (
        professions_df["model"] == model
    )
    image = professions_dset.select(professions_df[matches].index)["image"][0]
    folder_words = fname.split("/")[0].split("_")[4:]
    caption = "{} | {:.2f} | {}".format(" ".join(folder_words), score, models[model])
    return image, caption
def show_examplars(num_clusters, prof_name, cl_id, confidence_threshold=0.6):
    """Collect example images of a profession that were assigned to one cluster.

    Candidates are drawn from the cluster's "close" examples, the first two
    "mid" examples and the "far" examples (aggregated "All" model data).
    Repeats are removed, keeping first-seen order. Only entries whose score
    exceeds ``confidence_threshold`` are kept.

    Args:
        num_clusters: Clustering granularity, a key of ``clusters_dicts``.
        prof_name: Profession name as used in the cluster data.
        cl_id: Cluster id (int or str); looked up as ``str(cl_id)``.
        confidence_threshold: Exclusive lower bound on an example's score.

    Returns:
        A 2-tuple of (list of ``(image, caption)`` pairs from ``get_image``,
        ``gr.update`` that relabels the gallery).
    """
    examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
        "cluster_examplars"
    ][str(cl_id)]
    candidates = (
        examplars_dict["close"] + examplars_dict["mid"][:2] + examplars_dict["far"]
    )
    # Entries unpack as (score, model, image_path). Converting them to tuples
    # makes them hashable, so dict.fromkeys can drop repeats in O(n) while
    # preserving first-seen order.
    unique_examples = dict.fromkeys(tuple(img) for img in candidates)
    selected = [img for img in unique_examples if img[0] > confidence_threshold]
    return (
        [get_image(model, fname, score) for score, model, fname in selected],
        gr.update(
            label=f"Generations for profession '{prof_name}' assigned to cluster {cl_id} of {num_clusters}"
        ),
    )
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(_INTRO)
    gr.HTML(
        """<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
    )

    # ---- Tab 1: compare cluster distributions across several professions ----
    with gr.Tab("Professions Overview"):
        gr.Markdown(
            """
Select one or more professions and models from the dropdowns on the left to see which clusters are most representative for this combination.
Try choosing different numbers of clusters to see if the results change, and then go to the 'Profession Focus' tab to go more in-depth into these results.
The `Labor Women` column provided for comparison corresponds to the gender ratio reported by the
[U.S. Bureau of Labor Statistics](https://www.bls.gov/cps/cpsaat11.htm) for each profession.
"""
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Select the parameters here:")
                num_clusters = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                model_choices = gr.Dropdown(
                    [
                        "All Models",
                        "Stable Diffusion 1.4",
                        "Stable Diffusion 2",
                        "Dall-E 2",
                    ],
                    value="All Models",
                    label="Which models do you want to compare?",
                    interactive=True,
                )
                profession_choices_overview = gr.Dropdown(
                    professions,
                    value=[
                        "all professions",
                        "CEO",
                        "director",
                        "social assistant",
                        "social worker",
                    ],
                    label="Which professions do you want to compare?",
                    multiselect=True,
                    interactive=True,
                )
            with gr.Column(scale=3):
                with gr.Row():
                    # NOTE(review): `wrap=True` was dropped; it is not a documented
                    # gr.HTML argument. Confirm against the pinned Gradio version.
                    table = gr.HTML(label="Profession assignment per cluster")
        with gr.Row():
            # Hidden holder for the cluster ids returned by make_profession_table.
            clusters = gr.Textbox(label="clusters", visible=False)
            gr.Markdown(
                """
##### What do the clusters mean?
Below is a summary of the identity cluster compositions.
For more details, see the [companion demo](https://huggingface.co/spaces/tti-bias/diffusion-face-clustering):
"""
            )
        with gr.Row():
            with gr.Accordion(label="Cluster summaries", open=True):
                cluster_descriptions_table = gr.Text(
                    "TODO", label="Cluster summaries", show_label=False
                )

    # ---- Tab 2: drill into a single profession ----
    with gr.Tab("Profession Focus"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    "Select a profession to visualize and see which clusters and identity groups are most represented in the profession, as well as some examples of generated images below."
                )
                profession_choice_focus = gr.Dropdown(
                    choices=professions,
                    value="scientist",
                    label="Select profession:",
                )
                num_clusters_focus = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
            with gr.Column():
                # Static label: the plot is redrawn whenever the profession
                # changes, so a name baked in at build time would go stale.
                plot = gr.Plot(
                    label="Makeup of the cluster assignments for the selected profession"
                )
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """
##### What do the clusters mean?
Below is a summary of the identity cluster compositions.
For more details, see the [companion demo](https://huggingface.co/spaces/tti-bias/diffusion-face-clustering):
"""
                )
                with gr.Accordion(label="Cluster summaries", open=True):
                    cluster_descriptions = gr.Text(
                        "TODO", label="Cluster summaries", show_label=False
                    )
            with gr.Column():
                gr.Markdown(
                    """
##### What's in the clusters?
You can show examples of profession images assigned to each identity cluster by selecting one here:
"""
                )
                with gr.Accordion(label="Cluster selection", open=True):
                    cluster_id_focus = gr.Dropdown(
                        choices=list(range(num_clusters_focus.value)),
                        value=0,
                        label="Select cluster to visualize:",
                    )
        with gr.Row():
            examplars_plot = gr.Gallery(
                label="Profession images assigned to the selected cluster."
            ).style(grid=4, height="auto", container=True)

    # Populate every view once when the page loads.
    demo.load(
        make_profession_table,
        [num_clusters, profession_choices_overview, model_choices],
        [clusters, table, cluster_descriptions_table],
        queue=False,
    )
    demo.load(
        make_profession_plot,
        [num_clusters_focus, profession_choice_focus],
        [plot, cluster_id_focus, cluster_descriptions],
        queue=False,
    )
    # show_examplars returns (images, label update); both target the gallery.
    demo.load(
        show_examplars,
        [
            num_clusters_focus,
            profession_choice_focus,
            cluster_id_focus,
        ],
        [examplars_plot, examplars_plot],
        queue=False,
    )

    # Re-render each view whenever one of its inputs changes.
    for var in [num_clusters, model_choices, profession_choices_overview]:
        var.change(
            make_profession_table,
            [num_clusters, profession_choices_overview, model_choices],
            [clusters, table, cluster_descriptions_table],
            queue=False,
        )
    for var in [num_clusters_focus, profession_choice_focus]:
        var.change(
            make_profession_plot,
            [num_clusters_focus, profession_choice_focus],
            [plot, cluster_id_focus, cluster_descriptions],
            queue=False,
        )
    # NOTE(review): a change to num_clusters_focus or profession_choice_focus may
    # run show_examplars with the previous cluster_id_focus value, which might
    # not exist for the new selection. make_profession_plot also updates
    # cluster_id_focus, which presumably re-triggers this handler with a valid
    # id. Confirm whether the transient lookup error is acceptable.
    for var in [num_clusters_focus, profession_choice_focus, cluster_id_focus]:
        var.change(
            show_examplars,
            [
                num_clusters_focus,
                profession_choice_focus,
                cluster_id_focus,
            ],
            [examplars_plot, examplars_plot],
            queue=False,
        )
if __name__ == "__main__":
    # Enable Gradio's request queue, then launch the app with debug=True.
    app = demo.queue()
    app.launch(debug=True)