# Taken from https://huggingface.co/spaces/hysts-samples/save-user-preferences
# Credits to @hysts
import datetime
import json
import shutil
import tempfile
import uuid
from pathlib import Path
from typing import Any, Dict, List

import gradio as gr
import pyarrow as pa
import pyarrow.parquet as pq
from gradio_client import Client
from huggingface_hub import CommitScheduler

#######################
# Parquet scheduler   #
# Run in scheduler.py #
#######################


class ParquetScheduler(CommitScheduler):
    """Commit scheduler that buffers rows in memory and uploads them as Parquet.

    Rows are collected with `append`. Each call to `push_to_hub` drains the
    buffer, writes all pending rows to a single Parquet file (carrying
    `datasets` feature metadata), and uploads it under a random UUID filename.
    """

    def append(self, row: Dict[str, Any]) -> None:
        """Buffer one row for the next `push_to_hub` call.

        Acquires `self.lock`, the same lock `push_to_hub` holds while draining
        the buffer.
        """
        with self.lock:
            if not hasattr(self, "rows") or self.rows is None:
                self.rows = []
            self.rows.append(row)

    def set_schema(self, schema: Dict[str, Dict[str, str]]) -> None:
        """
        Define a schema to help `datasets` load the generated dataset.

        This method is optional and can be called once, just after the scheduler
        has been created. If it is not called, the schema is automatically
        inferred before pushing the data to the Hub.

        See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value
        for the list of possible values.

        Example:
        ```py
        scheduler.set_schema({
            "prompt": {"_type": "Value", "dtype": "string"},
            "negative_prompt": {"_type": "Value", "dtype": "string"},
            "guidance_scale": {"_type": "Value", "dtype": "int64"},
            "image": {"_type": "Image"},
        })
        ```
        """
        self._schema = schema

    def push_to_hub(self) -> None:
        """Upload every buffered row as one Parquet file, then clear the buffer.

        Overrides `CommitScheduler.push_to_hub`: instead of uploading the
        watched folder, it serialises the rows collected by `append`.
        - Image/Audio columns that hold an existing file path are replaced by
          `{"path": <file name>, "bytes": <file content>}`. The source file is
          deleted after a successful upload.
        - Columns missing from a row are filled with `None`.
        - If the upload raises, the drained rows are not re-queued.
        """
        # Take ownership of the pending rows under the lock.
        with self.lock:
            rows = getattr(self, "rows", None)
            self.rows = None
        if not rows:
            return
        print(f"Got {len(rows)} item(s) to commit.")

        # Load media files + build the 'features' config for the datasets library.
        # NOTE: a schema given via `set_schema` is extended in place with the
        # types inferred here, so later pushes reuse them.
        hf_features: Dict[str, Dict] = getattr(self, "_schema", None) or {}
        path_to_cleanup: List[Path] = []
        for row in rows:
            for key, value in row.items():
                # Infer schema (for the `datasets` library)
                if key not in hf_features:
                    hf_features[key] = _infer_schema(key, value)

                # Image/Audio columns are expected to hold a file path. Embed the
                # file's bytes and remember the file so it can be removed after
                # the upload. Non-path values (e.g. None) are left untouched:
                # Path(None) would raise TypeError.
                is_media = hf_features[key]["_type"] in ("Image", "Audio")
                if is_media and isinstance(value, (str, Path)):
                    file_path = Path(value)
                    if file_path.is_file():
                        row[key] = {
                            "path": file_path.name,
                            "bytes": file_path.read_bytes(),
                        }
                        path_to_cleanup.append(file_path)

        # Complete rows so every row has every column
        for row in rows:
            for feature in hf_features:
                if feature not in row:
                    row[feature] = None

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (read by the datasets library)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": hf_features}})}
        )

        # Honour the `path_in_repo` the scheduler was configured with; the
        # previous code always uploaded to the repository root.
        prefix = (getattr(self, "path_in_repo", "") or "").strip("/")
        file_name = f"{uuid.uuid4()}.parquet"
        path_in_repo = f"{prefix}/{file_name}" if prefix else file_name

        # Write the Parquet file inside a temporary directory. This works on
        # every platform, unlike reopening an open NamedTemporaryFile by name.
        # The directory is removed even when writing or uploading raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = Path(tmp_dir) / file_name
            pq.write_table(table, str(archive_path))
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=path_in_repo,
                path_or_fileobj=str(archive_path),
            )
        print("Commit completed.")

        # Delete the embedded source files only once the upload succeeded.
        for path in path_to_cleanup:
            path.unlink(missing_ok=True)


def _infer_schema(key: str, value: Any) -> Dict[str, str]:
    """
    Infer a `datasets` feature type for one column.

    Columns whose name contains "image" or "audio" are treated as media.
    Otherwise the type is derived from the Python type of `value`, with
    string as the fallback.

    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value.
    """
    if "image" in key:
        return {"_type": "Image"}
    if "audio" in key:
        return {"_type": "Audio"}
    # `bool` must be tested before `int`: bool is a subclass of int, so
    # checking int first labels True/False as int64.
    if isinstance(value, bool):
        return {"_type": "Value", "dtype": "bool"}
    if isinstance(value, int):
        return {"_type": "Value", "dtype": "int64"}
    if isinstance(value, float):
        return {"_type": "Value", "dtype": "float64"}
    if isinstance(value, bytes):
        return {"_type": "Value", "dtype": "binary"}
    # Otherwise, as a last resort, convert it to a string
    return {"_type": "Value", "dtype": "string"}


#################
# Gradio app    #
# Run in app.py #
#################

PARQUET_DATASET_DIR = Path("parquet_dataset")
PARQUET_DATASET_DIR.mkdir(parents=True, exist_ok=True)

scheduler = ParquetScheduler(
    repo_id="example-space-to-dataset-parquet",
    repo_type="dataset",
    folder_path=PARQUET_DATASET_DIR,
    path_in_repo="data",
)

client = Client("stabilityai/stable-diffusion")


def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images when the prompt is submitted.

    Returns the path of a temporary JSON file holding the generation config
    and the list of generated image paths read from `captions.json`.
    """
    # Generate from https://huggingface.co/spaces/stabilityai/stable-diffusion
    # The positional args mirror the saved config below: prompt, negative
    # prompt "", guidance scale 9.
    out_dir = client.predict(prompt, "", 9, fn_index=1)
    with (Path(out_dir) / "captions.json").open() as f:
        paths = list(json.load(f).keys())

    # Save the config used to generate the data. delete=False keeps the file
    # on disk so `save_preference` can read it later.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as config_file:
        json.dump(
            {"prompt": prompt, "negative_prompt": "", "guidance_scale": 9},
            config_file,
        )
    return config_file.name, paths


def get_selected_index(evt: gr.SelectData) -> int:
    """Return the index of the gallery image the user selected as "best"."""
    return evt.index


def save_preference(
    config_path: str, gallery: list[dict[str, Any]], selected_index: int
) -> None:
    """Save a preference: move the images to a new folder and queue a row.

    The row contains the generation config, the selected index, a UTC
    timestamp and one `image_NNN` entry per gallery image. It is handed to the
    scheduler, which uploads and then deletes the moved files.
    """
    save_dir = PARQUET_DATASET_DIR / f"{uuid.uuid4()}"
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load config
    with open(config_path) as f:
        data = json.load(f)

    # Add selected item + timestamp. Naive UTC ISO string, same format as the
    # deprecated `datetime.utcnow().isoformat()`.
    data["selected_index"] = selected_index
    data["timestamp"] = (
        datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
    )

    # Move the images into save_dir and add their paths to the row
    for index, path in enumerate(x["name"] for x in gallery):
        name = f"{index:03d}"
        dst_path = save_dir / f"{name}{Path(path).suffix}"
        shutil.move(path, dst_path)
        # Store a plain string: a Path object is not serialisable by Arrow if
        # the scheduler cannot embed the file.
        data[f"image_{name}"] = str(dst_path)

    # Send to scheduler
    scheduler.append(data)


def clear() -> tuple[dict, dict, dict]:
    """Reset the config path and gallery, and disable the save button."""
    return (gr.update(value=None), gr.update(value=None), gr.update(interactive=False))


def get_demo():
    """Build the preference-collection UI and wire its events.

    Presumably called by the app inside a `gr.Blocks()` context (event
    listeners are attached here). TODO confirm against the caller.
    """
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(
            columns=2, rows=2, height="600px", object_fit="scale-down"
        )
        selected_index = gr.Number(visible=False, precision=0)
        save_preference_button = gr.Button("Save preference", interactive=False)

    # Generate images on submit, then enable the save button
    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    ).success(
        fn=lambda: gr.update(interactive=True),
        outputs=save_preference_button,
        queue=False,
    )

    # Track the selected image
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )

    # Save preference on click, then reset the UI
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )