# Taken from https://huggingface.co/spaces/hysts-samples/save-user-preferences
# Credits to @hysts
import datetime
import json
import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Any, Dict, List

import gradio as gr
import pyarrow as pa
import pyarrow.parquet as pq
from gradio_client import Client
from huggingface_hub import CommitScheduler
#######################
# Parquet scheduler   #
# Run in scheduler.py #
#######################
class ParquetScheduler(CommitScheduler):
    """CommitScheduler that buffers rows in memory and pushes them to the Hub
    as uniquely-named parquet files.

    Rows are queued with `append` and flushed by `push_to_hub`, which the base
    CommitScheduler invokes periodically from a background thread.
    """

    def append(self, row: Dict[str, Any]) -> None:
        """Queue one row (flat dict of feature name -> value) for the next push."""
        with self.lock:
            # Lazily create the buffer: the base __init__ is untouched, so
            # `rows` may not exist yet, and push_to_hub resets it to None.
            if getattr(self, "rows", None) is None:
                self.rows = []
            self.rows.append(row)

    def set_schema(self, schema: Dict[str, Dict[str, str]]) -> None:
        """
        Define a schema to help `datasets` load the generated library.
        This method is optional and can be called once just after the scheduler had been created. If it is not called,
        the schema is automatically inferred before pushing the data to the Hub.
        See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
        possible values.
        Example:
        ```py
        scheduler.set_schema({
            "prompt": {"_type": "Value", "dtype": "string"},
            "negative_prompt": {"_type": "Value", "dtype": "string"},
            "guidance_scale": {"_type": "Value", "dtype": "int64"},
            "image": {"_type": "Image"},
        })
        ```
        """
        self._schema = schema

    def push_to_hub(self):
        """Flush queued rows to the Hub as a single parquet file, then clean up."""
        # Atomically take ownership of the pending rows and reset the buffer.
        with self.lock:
            rows = getattr(self, "rows", None)
            self.rows = None
        if not rows:
            return
        print(f"Got {len(rows)} item(s) to commit.")

        # Build the `datasets` features config. Copy the user schema so that
        # inferred entries do not permanently mutate self._schema across pushes.
        hf_features: Dict[str, Dict] = dict(getattr(self, "_schema", None) or {})
        path_to_cleanup: List[Path] = []
        for row in rows:
            for key, value in row.items():
                # Infer schema (for `datasets` library)
                if key not in hf_features:
                    hf_features[key] = _infer_schema(key, value)

                # Load binary files if necessary
                if hf_features[key]["_type"] in ("Image", "Audio"):
                    # It's an image or audio: load the bytes and remember to
                    # delete the source file once it is committed.
                    file_path = Path(value)
                    if file_path.is_file():
                        row[key] = {
                            "path": file_path.name,
                            "bytes": file_path.read_bytes(),
                        }
                        path_to_cleanup.append(file_path)

        # Complete rows if needed: every row must expose every feature.
        for row in rows:
            for feature in hf_features:
                if feature not in row:
                    row[feature] = None

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (used by the `datasets` library to reload typed columns)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": hf_features}})}
        )

        # Write to a parquet file and upload. The context manager guarantees the
        # temp file is removed even if the upload raises (the original closed it
        # only after a successful upload, leaking the file on error).
        with tempfile.NamedTemporaryFile(suffix=".parquet") as archive_file:
            pq.write_table(table, archive_file.name)
            # NOTE(review): the scheduler is configured with path_in_repo="data"
            # but files are uploaded at the repo root here — confirm intended.
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=f"{uuid.uuid4()}.parquet",
                path_or_fileobj=archive_file.name,
            )
        print("Commit completed.")

        # Cleanup: binary payloads now live inside the committed parquet file.
        for path in path_to_cleanup:
            path.unlink(missing_ok=True)
def _infer_schema(key: str, value: Any) -> Dict[str, str]: | |
""" | |
Infer schema for the `datasets` library. | |
See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value. | |
""" | |
if "image" in key: | |
return {"_type": "Image"} | |
if "audio" in key: | |
return {"_type": "Audio"} | |
if isinstance(value, int): | |
return {"_type": "Value", "dtype": "int64"} | |
if isinstance(value, float): | |
return {"_type": "Value", "dtype": "float64"} | |
if isinstance(value, bool): | |
return {"_type": "Value", "dtype": "bool"} | |
if isinstance(value, bytes): | |
return {"_type": "Value", "dtype": "binary"} | |
# Otherwise in last resort => convert it to a string | |
return {"_type": "Value", "dtype": "string"} | |
#################
# Gradio app    #
# Run in app.py #
#################

# Local staging directory; the CommitScheduler syncs its contents to the Hub.
PARQUET_DATASET_DIR = Path("parquet_dataset")
PARQUET_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# Background scheduler that periodically commits queued rows as parquet files
# under "data/" in the target dataset repo.
scheduler = ParquetScheduler(
    repo_id="example-space-to-dataset-parquet",
    repo_type="dataset",
    folder_path=PARQUET_DATASET_DIR,
    path_in_repo="data",
)

# Client for the upstream Stable Diffusion space used to generate images.
client = Client("stabilityai/stable-diffusion")
def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for *prompt* and return (config_file_path, image_paths)."""
    # Call https://huggingface.co/spaces/stabilityai/stable-diffusion
    out_dir = client.predict(prompt, "", 9, fn_index=1)

    # The space writes a captions.json whose keys are the generated image paths.
    captions_file = Path(out_dir) / "captions.json"
    with captions_file.open() as fp:
        image_paths = list(json.load(fp).keys())

    # Persist the generation parameters so they can be attached to the
    # user's preference later on.
    generation_config = {"prompt": prompt, "negative_prompt": "", "guidance_scale": 9}
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as config_file:
        json.dump(generation_config, config_file)
    return config_file.name, image_paths
def get_selected_index(evt: gr.SelectData) -> int:
    """Return the gallery index of the image the user clicked as "best"."""
    selected = evt.index
    return selected
def save_preference(
    config_path: str, gallery: list[dict[str, Any]], selected_index: int
) -> None:
    """Save preference, i.e. move images to a new folder and send paths+config to scheduler.

    Args:
        config_path: Path to the JSON config written by `generate`.
        gallery: Gallery items; each item's "name" key is an image file path.
        selected_index: Index of the image the user marked as best.
    """
    save_dir = PARQUET_DATASET_DIR / f"{uuid.uuid4()}"
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load config
    with open(config_path) as f:
        data = json.load(f)

    # Add selected item + timestamp. Use a timezone-aware UTC datetime:
    # datetime.utcnow() is deprecated and yields naive timestamps.
    data["selected_index"] = selected_index
    data["timestamp"] = datetime.datetime.now(datetime.timezone.utc).isoformat()

    # Move images into the staging dir and record their paths. Store paths as
    # plain strings (not Path objects) so rows stay serializable by pyarrow;
    # the scheduler re-wraps them with Path() before reading bytes.
    for index, path in enumerate(x["name"] for x in gallery):
        name = f"{index:03d}"
        dst_path = save_dir / f"{name}{Path(path).suffix}"
        shutil.move(path, dst_path)
        data[f"image_{name}"] = str(dst_path)

    # Send to scheduler
    scheduler.append(data)
def clear() -> tuple[dict, dict, dict]:
    """Reset the config path and gallery, and disable the save button."""
    cleared_config = gr.update(value=None)
    cleared_gallery = gr.update(value=None)
    disabled_button = gr.update(interactive=False)
    return (cleared_config, cleared_gallery, disabled_button)
def get_demo():
    """Build the Gradio UI: prompt box, image gallery, and 'save preference' flow."""
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        # Hidden field holding the path of the JSON config written by `generate`.
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(
            columns=2, rows=2, height="600px", object_fit="scale-down"
        )
        # Hidden field holding the index of the gallery image the user selected.
        selected_index = gr.Number(visible=False, precision=0)
        # Disabled until a generation succeeds.
        save_preference_button = gr.Button("Save preference", interactive=False)

    # Generate images on submit; enable the save button only on success.
    prompt.submit(fn=generate, inputs=prompt, outputs=[config_path, gallery],).success(
        fn=lambda: gr.update(interactive=True),
        outputs=save_preference_button,
        queue=False,
    )

    # Record which image the user clicked in the gallery.
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )
    # Save preference on click, then clear the UI for the next prompt.
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )