# Hugging Face Space demo: generate images from text prompts and persist
# them (plus prompt metadata) to a dataset repository on the Hub.
import json | |
import tempfile | |
import zipfile | |
from datetime import datetime | |
from pathlib import Path | |
from uuid import uuid4 | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from huggingface_hub import CommitScheduler, InferenceClient | |
# Local staging folder for generated images; the uuid4 suffix gives each
# run its own directory so restarts never write into a stale folder.
IMAGE_DATASET_DIR = Path("image_dataset_1M") / f"train-{uuid4()}"
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
# JSON-lines file with one {prompt, file_name, datetime} record per image.
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"
class ZipScheduler(CommitScheduler):
    """
    Custom ``CommitScheduler`` whose overridden ``push_to_hub`` zips images
    before pushing them to the Hub.

    Workflow:
        1. Read metadata + list PNG files.
        2. Zip the PNG files and the metadata into a single archive.
        3. Create a commit (upload the archive).
        4. Delete the local PNG files to avoid re-uploading them later.

    Only step 1 requires the lock. Once the metadata is read, the lock is
    released and the rest of the process runs without blocking the Gradio app.
    """

    def push_to_hub(self) -> None:
        # 1. Read metadata + list PNG files. The lock keeps the Gradio
        #    callback from writing while we snapshot the current state.
        with self.lock:
            png_files = list(self.folder_path.glob("*.png"))
            if not png_files:
                return  # return early if nothing to commit
            # Read and delete the metadata file; its content travels inside
            # the archive instead. Deletion is best-effort (missing_ok).
            metadata = IMAGE_JSONL_PATH.read_text()
            IMAGE_JSONL_PATH.unlink(missing_ok=True)

        with tempfile.TemporaryDirectory() as tmpdir:
            # 2. Zip PNG files + metadata into a single archive.
            #    Named `archive` so we don't shadow the builtin `zip`.
            archive_path = Path(tmpdir) / "train.zip"
            with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as archive:
                # PNG files
                for png_file in png_files:
                    archive.write(filename=png_file, arcname=png_file.name)
                # Metadata
                tmp_metadata = Path(tmpdir) / "metadata.jsonl"
                tmp_metadata.write_text(metadata)
                archive.write(filename=tmp_metadata, arcname="metadata.jsonl")

            # 3. Create commit. The random archive name avoids clobbering
            #    archives uploaded by previous pushes.
            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                path_in_repo=f"train-{uuid4()}.zip",
                path_or_fileobj=archive_path,
            )

        # 4. Delete local PNG files to avoid re-uploading them later
        #    (best-effort: a file already gone is not an error).
        for png_file in png_files:
            png_file.unlink(missing_ok=True)
# Background scheduler: periodically calls push_to_hub() to zip staged PNGs
# plus metadata and commit the archive to the dataset repo.
scheduler = ZipScheduler(
    repo_id="example-space-to-dataset-image-zip",
    repo_type="dataset",
    folder_path=IMAGE_DATASET_DIR,
)
# Inference API client used below for text-to-image generation.
client = InferenceClient()
def generate_image(prompt: str) -> Image.Image:
    """Generate an image for `prompt` via the shared Inference API client.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        The generated picture as a ``PIL.Image.Image`` (the previous
        annotation ``Image`` referred to the PIL *module*, not the class).
    """
    return client.text_to_image(prompt)
def save_image(prompt: str, image_array: np.ndarray) -> None:
    """Persist one generated image and its metadata for later upload.

    Saves the PNG under ``IMAGE_DATASET_DIR`` (random filename) and appends
    a {prompt, file_name, datetime} record to the shared JSONL file. Both
    writes happen under the scheduler lock so a concurrent push never sees
    a half-written state.
    """
    print("Saving: " + prompt)
    target_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
    with scheduler.lock:
        Image.fromarray(image_array).save(target_path)
        with IMAGE_JSONL_PATH.open("a") as jsonl_file:
            record = {
                "prompt": prompt,
                "file_name": target_path.name,
                "datetime": datetime.now().isoformat(),
            }
            json.dump(record, jsonl_file)
            jsonl_file.write("\n")
def get_demo():
    """Build the demo UI: a prompt box, an image preview, and a button.

    Clicking the button runs ``generate_image``; once that succeeds, the
    prompt and the resulting image are handed to ``save_image``.
    """
    with gr.Row():
        prompt_widget = gr.Textbox(label="Prompt")
        image_widget = gr.Image(label="Generated image")
    generate_btn = gr.Button("Generate")
    generate_btn.click(
        fn=generate_image,
        inputs=prompt_widget,
        outputs=image_widget,
    ).success(fn=save_image, inputs=[prompt_widget, image_widget], outputs=None)