Spaces:
Sleeping
Sleeping
# Taken from https://huggingface.co/spaces/hysts-samples/save-user-preferences
# Credits to @hysts
import datetime | |
import json | |
import shutil | |
import tempfile | |
import uuid | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional, Union | |
import gradio as gr | |
import pyarrow as pa | |
import pyarrow.parquet as pq | |
from gradio_client import Client | |
from huggingface_hub import CommitScheduler | |
from huggingface_hub.hf_api import HfApi | |
#######################
# Parquet scheduler   #
# Run in scheduler.py #
#######################
class ParquetScheduler(CommitScheduler):
    """
    Scheduler that buffers rows in memory and uploads them to the Hub as Parquet files.

    Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub. 1 `.append`
    call will result in 1 row in your final dataset.

    ```py
    # Start scheduler
    >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")

    # Append some data to be uploaded
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    ```

    The scheduler will automatically infer the schema from the data it pushes.
    Optionally, you can manually set the schema yourself:

    ```py
    >>> scheduler = ParquetScheduler(
    ...     repo_id="my-parquet-dataset",
    ...     schema={
    ...         "prompt": {"_type": "Value", "dtype": "string"},
    ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
    ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
    ...         "image": {"_type": "Image"},
    ...     },
    ... )
    ```

    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
    possible values.
    """

    def __init__(
        self,
        *,
        repo_id: str,
        schema: Optional[Dict[str, Dict[str, str]]] = None,
        every: Union[int, float] = 5,
        path_in_repo: Optional[str] = "data",
        repo_type: Optional[str] = "dataset",
        revision: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        allow_patterns: Union[List[str], str, None] = None,
        ignore_patterns: Union[List[str], str, None] = None,
        hf_api: Optional[HfApi] = None,
    ) -> None:
        """Create the scheduler.

        Args:
            repo_id: Repository the Parquet files are uploaded to.
            schema: Optional mapping of column name to a `datasets` feature
                description. Columns missing from it are inferred from the
                first value seen. When a non-empty dict is given, inferred
                columns are added to that same dict.

        All other arguments are forwarded unchanged to `CommitScheduler`.
        """
        super().__init__(
            repo_id=repo_id,
            folder_path="dummy",  # not used by the scheduler
            every=every,
            path_in_repo=path_in_repo,
            repo_type=repo_type,
            revision=revision,
            private=private,
            token=token,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            hf_api=hf_api,
        )
        # Rows appended since the last push; guarded by `self.lock`.
        self._rows: List[Dict[str, Any]] = []
        self._schema = schema

    def append(self, row: Dict[str, Any]) -> None:
        """Add a new item to be uploaded."""
        with self.lock:
            self._rows.append(row)

    def push_to_hub(self):
        """Upload every buffered row as one new Parquet file.

        Returns immediately when nothing was appended since the previous call.
        The buffer is emptied before the upload starts, so rows are not
        re-queued if writing or uploading raises; the exception propagates.
        Image/audio files whose bytes were embedded are deleted from disk only
        after a successful upload.
        """
        # Check for new rows to push. Swap the buffer out under the lock so
        # that `append` is only blocked for the duration of the swap.
        with self.lock:
            rows = self._rows
            self._rows = []
        if not rows:
            return
        print(f"Got {len(rows)} item(s) to commit.")

        # Load images + create 'features' config for datasets library
        schema: Dict[str, Dict] = self._schema or {}
        path_to_cleanup: List[Path] = []
        for row in rows:
            for key, value in row.items():
                # Infer schema (for `datasets` library)
                if key not in schema:
                    schema[key] = _infer_schema(key, value)

                # Load binary files if necessary
                if schema[key]["_type"] in ("Image", "Audio"):
                    # A missing value has no file to load; keep it as a null cell
                    # instead of crashing on `Path(None)`.
                    if value is None:
                        continue
                    # It's an image or audio: we load the bytes and remember to cleanup the file
                    file_path = Path(value)
                    if file_path.is_file():
                        # Replacing the value of an existing key is safe while iterating.
                        row[key] = {
                            "path": file_path.name,
                            "bytes": file_path.read_bytes(),
                        }
                        path_to_cleanup.append(file_path)

        # Complete rows if needed so that every row has every column
        for row in rows:
            for feature in schema:
                row.setdefault(feature, None)

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (used by datasets library)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": schema}})}
        )

        # Write to a temporary parquet file and upload it. The context manager
        # closes (and thereby removes) the file even if writing or uploading fails.
        with tempfile.NamedTemporaryFile() as archive_file:
            pq.write_table(table, archive_file.name)
            # NOTE(review): the file is uploaded at the repo root under a random
            # name; `self.path_in_repo` is not used here. Confirm this is intended.
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=f"{uuid.uuid4()}.parquet",
                path_or_fileobj=archive_file.name,
            )
        print("Commit completed.")

        # Cleanup the source files whose bytes are now stored in the dataset
        for path in path_to_cleanup:
            path.unlink(missing_ok=True)
def _infer_schema(key: str, value: Any) -> Dict[str, str]: | |
""" | |
Infer schema for the `datasets` library. | |
See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value. | |
""" | |
if "image" in key: | |
return {"_type": "Image"} | |
if "audio" in key: | |
return {"_type": "Audio"} | |
if isinstance(value, int): | |
return {"_type": "Value", "dtype": "int64"} | |
if isinstance(value, float): | |
return {"_type": "Value", "dtype": "float64"} | |
if isinstance(value, bool): | |
return {"_type": "Value", "dtype": "bool"} | |
if isinstance(value, bytes): | |
return {"_type": "Value", "dtype": "binary"} | |
# Otherwise in last resort => convert it to a string | |
return {"_type": "Value", "dtype": "string"} | |
#################
# Gradio app    #
# Run in app.py #
#################
# Local folder that receives the images of each saved preference; created at import time.
PARQUET_DATASET_DIR = Path("parquet_dataset")
PARQUET_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# Scheduler that collects the saved preferences for upload.
scheduler = ParquetScheduler(repo_id="example-space-to-dataset-parquet")

# Client of the remote Space used to generate images.
# client = Client("stabilityai/stable-diffusion")  # Space is paused
client = Client("runwayml/stable-diffusion-v1-5")
def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for `prompt` when the prompt is submitted.

    Calls the remote Space through the module-level `client`, then stores the
    generation config in a JSON temp file that is kept on disk
    (`delete=False`).

    Returns:
        The path of the JSON config file and the list of generated image
        paths (the keys of the `captions.json` file returned by the Space).
    """
    # Previous call, for the paused 'stabilityai/stable-diffusion' Space:
    # out_dir = client.predict(prompt, "", 9, fn_index=1)
    result_dir = Path(client.predict(prompt, fn_index=1))
    captions_file = result_dir / "captions.json"
    with captions_file.open() as captions:
        captions_by_path = json.load(captions)
        image_paths = list(captions_by_path.keys())

    # Save config used to generate data
    config = {"prompt": prompt, "negative_prompt": "", "guidance_scale": 9}
    config_file = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
    with config_file:
        json.dump(config, config_file)

    return config_file.name, image_paths
def get_selected_index(evt: gr.SelectData) -> int:
    """Return the index carried by a gallery selection event (the "best" image)."""
    chosen = evt.index
    return chosen
def save_preference(
    config_path: str, gallery: list[dict[str, Any]], selected_index: int
) -> None:
    """Save a preference: move the images to a new folder and queue paths + config.

    Args:
        config_path: Path of the JSON config file written by `generate`. The
            file is deleted once the row has been handed to the scheduler.
        gallery: Gallery items; each item's "name" entry is the path of an image.
        selected_index: Index of the image the user selected.
    """
    save_dir = PARQUET_DATASET_DIR / f"{uuid.uuid4()}"
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load config
    with open(config_path) as f:
        data = json.load(f)

    # Add selected item + timestamp.
    # `datetime.utcnow()` is deprecated since Python 3.12. Build an aware UTC
    # datetime and drop the tzinfo so the stored string keeps the same naive
    # ISO format as before.
    data["selected_index"] = selected_index
    data["timestamp"] = (
        datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
    )

    # Move images into the save folder and reference them in the row
    for index, item in enumerate(gallery):
        path = item["name"]
        name = f"{index:03d}"
        dst_path = save_dir / f"{name}{Path(path).suffix}"
        shutil.move(path, dst_path)
        # Store a plain string rather than a `Path` object in the row.
        data[f"image_{name}"] = str(dst_path)

    # Send to scheduler
    scheduler.append(data)

    # The config temp file was created with delete=False; remove it now that
    # its content is part of the queued row, so temp files do not pile up.
    Path(config_path).unlink(missing_ok=True)
def clear() -> tuple[dict, dict, dict]:
    """Reset the UI once a preference is saved.

    Returns three updates: two that set a value to `None` and one that makes
    a component non-interactive, in that order.
    """
    reset_first = gr.update(value=None)
    reset_second = gr.update(value=None)
    disable_third = gr.update(interactive=False)
    return (reset_first, reset_second, disable_third)
def get_demo():
    """Create the demo components and wire up their events.

    Flow: submitting the prompt runs `generate`; on success the save button
    becomes interactive. Selecting a gallery image records its index. Clicking
    the save button runs `save_preference`, then `clear`.
    """
    # NOTE(review): the original indentation was lost, so which components sit
    # inside `gr.Group()` is reconstructed here — confirm against the layout.
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(
            columns=2, rows=2, height="600px", object_fit="scale-down"
        )
        selected_index = gr.Number(visible=False, precision=0)
    save_preference_button = gr.Button("Save preference", interactive=False)

    # Generate images on submit, then enable the save button on success
    generation_event = prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    )
    generation_event.success(
        fn=lambda: gr.update(interactive=True),
        outputs=save_preference_button,
        queue=False,
    )

    # Remember which image was selected
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )

    # Save preference on click, then reset the UI
    save_event = save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    )
    save_event.then(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )