# Sports Cards Dataset Processor — Gradio app (Hugging Face Space)
import logging
import os

import gradio as gr
import pandas as pd
import yaml
from datasets import Dataset, concatenate_datasets, load_dataset
from PIL import Image
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
from tqdm import tqdm
# Set up module-level logging (INFO so download progress/errors are visible in Space logs)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load settings from the Space's settings.yaml.
# Expected key: 'client_secrets_file' — path to the Google OAuth client config
# consumed by authenticate_drive() below.
with open('settings.yaml', 'r') as file:
    settings = yaml.safe_load(file)
class DatasetManager:
    """Download card images from Google Drive and sync them into the
    ``GotThatData/sports-cards`` dataset on the Hugging Face Hub.

    Typical flow: ``authenticate_drive()`` → ``download_and_rename_files()``
    → ``update_huggingface_dataset()``.
    """

    def __init__(self, local_images_dir="downloaded_cards"):
        """
        Args:
            local_images_dir: Local directory where downloaded images are stored.
        """
        self.local_images_dir = local_images_dir
        self.drive = None  # set by authenticate_drive(); None until then
        self.dataset_name = "GotThatData/sports-cards"
        # Create local directory if it doesn't exist
        os.makedirs(local_images_dir, exist_ok=True)

    def authenticate_drive(self):
        """Authenticate with Google Drive via PyDrive2.

        Returns:
            Tuple of (success: bool, message: str).
        """
        try:
            gauth = GoogleAuth()
            # Point PyDrive2 at the OAuth client config named in settings.yaml
            gauth.settings['client_config_file'] = settings['client_secrets_file']
            # Reuse previously saved credentials when available
            gauth.LoadCredentialsFile("credentials.txt")
            if gauth.credentials is None:
                # First run: interactive browser-based OAuth flow
                gauth.LocalWebserverAuth()
            elif gauth.access_token_expired:
                # Refresh the token instead of re-prompting the user
                gauth.Refresh()
            else:
                # Credentials present and valid: just authorize the session
                gauth.Authorize()
            # Persist credentials so future runs skip the browser flow
            gauth.SaveCredentialsFile("credentials.txt")
            self.drive = GoogleDrive(gauth)
            return True, "Successfully authenticated with Google Drive"
        except Exception as e:
            return False, f"Authentication failed: {str(e)}"

    def download_and_rename_files(self, drive_folder_id, naming_convention):
        """Download image files from Google Drive and rename them sequentially.

        Numbering continues after any images already in the Hub dataset, and
        only successfully verified images consume a number — skipped non-image
        files or corrupt downloads no longer leave gaps in the sequence (the
        original implementation numbered by the raw enumerate index).

        Args:
            drive_folder_id: ID of a Drive folder (or of a single file).
            naming_convention: Filename prefix, e.g. "sports_card".

        Returns:
            Tuple of (success: bool, message: str, renamed_files: list[dict]).
        """
        if not self.drive:
            return False, "Google Drive not authenticated", []
        try:
            # List files in the folder
            query = f"'{drive_folder_id}' in parents and trashed=false"
            file_list = self.drive.ListFile({'q': query}).GetList()
            if not file_list:
                # The ID may refer to a single file rather than a folder
                file = self.drive.CreateFile({'id': drive_folder_id})
                if file:
                    file_list = [file]
                else:
                    return False, "No files found with the specified ID", []
            # Determine where numbering should continue from
            try:
                existing_dataset = load_dataset(self.dataset_name)
                logger.info(f"Loaded existing dataset: {self.dataset_name}")
                start_index = len(existing_dataset['train']) if 'train' in existing_dataset else 0
            except Exception as e:
                logger.info(f"No existing dataset found, starting fresh: {str(e)}")
                start_index = 0
            renamed_files = []
            next_number = start_index + 1  # advances only on a kept image
            for file in tqdm(file_list, desc="Downloading files"):
                if not file['mimeType'].startswith('image/'):
                    continue  # skip non-image files without consuming a number
                new_filename = f"{naming_convention}_{next_number}.jpg"
                file_path = os.path.join(self.local_images_dir, new_filename)
                # Download file
                file.GetContentFile(file_path)
                # Verify the download is a readable image before keeping it
                try:
                    with Image.open(file_path) as img:
                        img.verify()
                    renamed_files.append({
                        'file_path': file_path,
                        'original_name': file['title'],
                        'new_name': new_filename,
                        'image': file_path
                    })
                    next_number += 1
                except Exception as e:
                    logger.error(f"Error processing image {file['title']}: {str(e)}")
                    # Remove the corrupt/partial download so it is not re-used
                    if os.path.exists(file_path):
                        os.remove(file_path)
            return True, f"Successfully processed {len(renamed_files)} images", renamed_files
        except Exception as e:
            return False, f"Error downloading files: {str(e)}", []

    def update_huggingface_dataset(self, renamed_files):
        """Append newly downloaded images to the sports-cards Hub dataset.

        Args:
            renamed_files: Records produced by download_and_rename_files().

        Returns:
            Tuple of (success: bool, message: str).
        """
        try:
            # Build a Dataset from the new file records
            df = pd.DataFrame(renamed_files)
            new_dataset = Dataset.from_pandas(df)
            try:
                # Append to the published dataset when one already exists
                existing_dataset = load_dataset(self.dataset_name)
                if 'train' in existing_dataset:
                    # NOTE: concatenate_datasets was referenced but never
                    # imported in the original file (NameError); it is now
                    # imported at the top of the module.
                    new_dataset = concatenate_datasets([existing_dataset['train'], new_dataset])
            except Exception:
                logger.info("Creating new dataset")
            # Push to Hugging Face Hub
            new_dataset.push_to_hub(self.dataset_name, split="train")
            return True, f"Successfully updated dataset '{self.dataset_name}' with {len(renamed_files)} new images"
        except Exception as e:
            return False, f"Error updating Hugging Face dataset: {str(e)}"
def process_pipeline(folder_id, naming_convention):
    """Run the full workflow: authenticate, download/rename, push to the Hub.

    Args:
        folder_id: Google Drive file or folder ID.
        naming_convention: Filename prefix for the renamed images.

    Returns:
        A human-readable status string for the Gradio UI.
    """
    manager = DatasetManager()

    # Authenticate with Google Drive; bail out early on failure.
    ok, auth_message = manager.authenticate_drive()
    if not ok:
        return auth_message

    # Pull the images down and assign them sequential names.
    ok, download_message, renamed = manager.download_and_rename_files(
        folder_id, naming_convention
    )
    if not ok:
        return download_message

    # Publish the new images; report both stages' outcomes to the UI.
    _, hub_message = manager.update_huggingface_dataset(renamed)
    return f"{download_message}\n{hub_message}"
# Gradio interface: two text inputs (Drive ID + filename prefix) feeding
# process_pipeline, with its status string as the single output.
demo = gr.Interface(
    fn=process_pipeline,
    inputs=[
        gr.Textbox(
            label="Google Drive File/Folder ID",
            placeholder="Enter the ID from your Google Drive URL",
            # Pre-filled with a sample Drive ID — presumably the maintainer's
            # card folder; TODO confirm it should ship as the default.
            value="151VOxPO91mg0C3ORiioGUd4hogzP1ujm"
        ),
        gr.Textbox(
            label="Naming Convention",
            placeholder="e.g., sports_card",
            value="sports_card"
        )
    ],
    outputs=gr.Textbox(label="Status"),
    title="Sports Cards Dataset Processor",
    description="Download card images from Google Drive and add them to the sports-cards dataset"
)
# Start the Gradio web server when the file is executed directly.
if __name__ == "__main__":
    demo.launch()