Spaces:
Running
on
Zero
Running
on
Zero
Commit
Β·
3bb5a93
1
Parent(s):
3dcef48
ad toxicity check
Browse files
app.py
CHANGED
@@ -6,7 +6,6 @@ import multiprocessing
|
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
import polars as pl
|
9 |
-
import numpy as np
|
10 |
import matplotlib.pyplot as plt
|
11 |
import spaces
|
12 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
@@ -90,12 +89,83 @@ def plot_and_df(texts, preds):
|
|
90 |
)
|
91 |
|
92 |
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
def run_quality_check(dataset, column, batch_size, num_examples):
|
95 |
-
# config = "default"
|
96 |
info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
|
97 |
if "error" in info_resp:
|
98 |
-
yield "β " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure()
|
99 |
return
|
100 |
config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
|
101 |
split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
|
@@ -106,9 +176,10 @@ def run_quality_check(dataset, column, batch_size, num_examples):
|
|
106 |
try:
|
107 |
data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
|
108 |
except Exception as error:
|
109 |
-
yield f"β {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure()
|
110 |
return
|
111 |
texts = data[column].to_list()
|
|
|
112 |
# batch_size = 100
|
113 |
predictions, texts_processed = [], []
|
114 |
num_examples = min(len(texts), num_examples)
|
@@ -117,7 +188,7 @@ def run_quality_check(dataset, column, batch_size, num_examples):
|
|
117 |
batch_predictions = predict(batch_texts)
|
118 |
predictions.extend(batch_predictions)
|
119 |
texts_processed.extend(batch_texts)
|
120 |
-
yield {"check in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions), plt.Figure()
|
121 |
|
122 |
with multiprocessing.Pool(processes=8) as pool:
|
123 |
props = pool.map(proportion_non_ascii, texts)
|
@@ -128,7 +199,8 @@ def run_quality_check(dataset, column, batch_size, num_examples):
|
|
128 |
plt.xlabel('Proportion of non-ASCII characters')
|
129 |
plt.ylabel('Number of texts')
|
130 |
|
131 |
-
yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), plt.gcf()
|
|
|
132 |
|
133 |
with gr.Blocks() as demo:
|
134 |
gr.Markdown(
|
@@ -175,6 +247,13 @@ with gr.Blocks() as demo:
|
|
175 |
|
176 |
# non_ascii_hist = gr.DataFrame(visible=False)
|
177 |
non_ascii_hist = gr.Plot()
|
178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
demo.launch()
|
|
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
import polars as pl
|
|
|
9 |
import matplotlib.pyplot as plt
|
10 |
import spaces
|
11 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
|
|
89 |
)
|
90 |
|
91 |
|
92 |
+
PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
|
93 |
+
PERSPECTIVE_URL = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={PERSPECTIVE_API_KEY}"
|
94 |
+
REQUESTED_ATTRIBUTES = {"TOXICITY": {}, "SEVERE_TOXICITY": {},
|
95 |
+
"IDENTITY_ATTACK": {}, "INSULT": {}, "PROFANITY": {},
|
96 |
+
"THREAT": {}}
|
97 |
+
ATT_SCORE = "attributeScores"
|
98 |
+
SUM_SCORE = "summaryScore"
|
99 |
+
|
100 |
+
|
101 |
+
def plot_toxicity(scores):
|
102 |
+
fig, axs = plt.subplots(2, 3)#, figsize=(10, 6))
|
103 |
+
for x, y, score_name in zip([0,0,0,1,1,1], [0,1,2,0,1,2], scores):
|
104 |
+
axs[x,y].hist(scores[score_name], bins=20, range=(0., 1.))
|
105 |
+
# axs[x,y].set_title(f'Histogram of {score_name}')
|
106 |
+
axs[x,y].set_xlabel(f'{score_name}')
|
107 |
+
# axs[x,y].set_ylabel('Number of texts')
|
108 |
+
fig.supylabel("Number of texts")
|
109 |
+
fig.suptitle("Histogram of toxicity scores")
|
110 |
+
fig.tight_layout()
|
111 |
+
|
112 |
+
return fig
|
113 |
+
|
114 |
+
def call_perspective_api(texts_df, column_name):#, s):
|
115 |
+
headers = {
|
116 |
+
"content-type": "application/json",
|
117 |
+
}
|
118 |
+
req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
|
119 |
+
|
120 |
+
texts = texts_df[column_name].values
|
121 |
+
for i, text in tqdm(enumerate(texts), desc="scanning with perspective"):
|
122 |
+
data = {
|
123 |
+
"comment": {"text": text},
|
124 |
+
"languages": ["en"],
|
125 |
+
"requestedAttributes": REQUESTED_ATTRIBUTES
|
126 |
+
}
|
127 |
+
time.sleep(1)
|
128 |
+
try:
|
129 |
+
req_response = requests.post(PERSPECTIVE_URL, json=data, headers=headers)
|
130 |
+
except Exception as e:
|
131 |
+
print(e)
|
132 |
+
return req_att_scores
|
133 |
+
|
134 |
+
if req_response.ok:
|
135 |
+
response = req_response.json()
|
136 |
+
# logger.info("Perspective API response is:")
|
137 |
+
# logger.info(response)
|
138 |
+
if ATT_SCORE in response:
|
139 |
+
for req_att in REQUESTED_ATTRIBUTES:
|
140 |
+
if req_att in response[ATT_SCORE]:
|
141 |
+
att_score = response[ATT_SCORE][req_att][SUM_SCORE]["value"]
|
142 |
+
req_att_scores[req_att].append(att_score)
|
143 |
+
else:
|
144 |
+
req_att_scores[req_att].append(0)
|
145 |
+
else:
|
146 |
+
# logger.error(
|
147 |
+
# "Unexpected response format from Perspective API."
|
148 |
+
# )
|
149 |
+
raise ValueError(req_response)
|
150 |
+
else:
|
151 |
+
try:
|
152 |
+
req_response.raise_for_status()
|
153 |
+
except Exception as e:
|
154 |
+
print(e)
|
155 |
+
return req_att_scores
|
156 |
+
if i % 10 == 0:
|
157 |
+
plot_toxicity(req_att_scores)
|
158 |
+
yield plt.gcf(), pd.DataFrame()
|
159 |
+
|
160 |
+
plot_toxicity(req_att_scores)
|
161 |
+
yield plt.gcf(), pd.DataFrame.from_dict({column_name: texts, **req_att_scores})
|
162 |
+
|
163 |
+
|
164 |
+
# @spaces.GPU
|
165 |
def run_quality_check(dataset, column, batch_size, num_examples):
|
|
|
166 |
info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
|
167 |
if "error" in info_resp:
|
168 |
+
yield "β " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
|
169 |
return
|
170 |
config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
|
171 |
split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
|
|
|
176 |
try:
|
177 |
data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
|
178 |
except Exception as error:
|
179 |
+
yield f"β {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
|
180 |
return
|
181 |
texts = data[column].to_list()
|
182 |
+
texts_sample = data.sample(20, shuffle=True, seed=16).to_pandas()
|
183 |
# batch_size = 100
|
184 |
predictions, texts_processed = [], []
|
185 |
num_examples = min(len(texts), num_examples)
|
|
|
188 |
batch_predictions = predict(batch_texts)
|
189 |
predictions.extend(batch_predictions)
|
190 |
texts_processed.extend(batch_texts)
|
191 |
+
yield {"check in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions), plt.Figure(), pd.DataFrame()
|
192 |
|
193 |
with multiprocessing.Pool(processes=8) as pool:
|
194 |
props = pool.map(proportion_non_ascii, texts)
|
|
|
199 |
plt.xlabel('Proportion of non-ASCII characters')
|
200 |
plt.ylabel('Number of texts')
|
201 |
|
202 |
+
yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), plt.gcf(), texts_sample
|
203 |
+
|
204 |
|
205 |
with gr.Blocks() as demo:
|
206 |
gr.Markdown(
|
|
|
247 |
|
248 |
# non_ascii_hist = gr.DataFrame(visible=False)
|
249 |
non_ascii_hist = gr.Plot()
|
250 |
+
texts_sample_df = gr.DataFrame(visible=False)
|
251 |
+
gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size, num_examples], outputs=[progress_bar, plot, df_low, df_medium, df_high, non_ascii_hist, texts_sample_df])
|
252 |
+
|
253 |
+
gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
|
254 |
+
toxicity_hist = gr.Plot()
|
255 |
+
with gr.Accordion("Explore examples with toxicity scores:", open=False):
|
256 |
+
toxicity_df = gr.DataFrame()
|
257 |
+
gr_toxicity_btn.click(call_perspective_api, inputs=[texts_sample_df, text_column], outputs=[toxicity_hist, toxicity_df])
|
258 |
|
259 |
demo.launch()
|