import spaces
import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import os
import shutil
from huggingface_hub import HfApi, HfFolder
from transformers import AutoModelForImageClassification, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from sklearn.metrics import accuracy_score
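

# @spaces.GPU marks a function for GPU allocation on Hugging Face ZeroGPU Spaces.
# Here it decorates a no-op (presumably just to satisfy the Space's GPU check),
# since training below is explicitly configured to run on CPU.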
@spaces.GPU
def dummy_gpu():
    pass
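

# Defaults; HF_REPO, HF_DATASET, and HF_MODEL can be overridden via environment
# variables of the same name, and HF_TOKEN is read from the environment if set.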
HF_MODEL = "google/vit-base-patch16-224"
HF_DATASET = "verytuffcat/recaptcha-dataset"
HF_REPO = ""
HF_TOKEN = os.getenv("HF_TOKEN", "")
if os.getenv("HF_REPO"): HF_REPO = os.getenv("HF_REPO")
if os.getenv("HF_DATASET"): HF_DATASET = os.getenv("HF_DATASET")
if os.getenv("HF_MODEL"): HF_MODEL = os.getenv("HF_MODEL")
OUT_DIR = "./new_model"


def pil_to_torch(image: Image.Image):
    # Resize to the 224x224 ViT input size and convert to a CHW float tensor in [0, 1].
    return torchvision.transforms.functional.to_tensor(image.convert("RGB").resize((224, 224), Image.BICUBIC))


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    metrics = dict(accuracy=accuracy_score(labels, predictions))
    return metrics


def collate_fn(batch):
    # Stack images into a batch tensor and collect integer labels, matching the
    # "pixel_values"/"labels" inputs expected by the image classification model.
    pixel_values = torch.stack([pil_to_torch(x["image"]) for x in batch])
    labels = torch.tensor([x["label"] for x in batch])
    return {"pixel_values": pixel_values, "labels": labels}
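

# Fine-tune `model_id` on `dataset_id` and upload the trained model to `repo_id` on the Hugging Face Hub.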
def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, slice: int, log_md: str, progress=gr.Progress(track_tqdm=True)):
    try:
        if not model_id or not dataset_id or not repo_id: raise gr.Error("Fill in all required fields.")
        if not hf_token: hf_token = HF_TOKEN
        if not hf_token: raise gr.Error("Provide an HF write token.")
        HfFolder.save_token(hf_token)
        api = HfApi(token=hf_token)
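
        # Load the training split (optionally limited to `slice` examples)
        # and build label <-> id mappings from the dataset's class names.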
        if slice >= 1: dataset = load_dataset(dataset_id, split=f"train[1:{int(slice)+1}]", num_proc=8)
        else: dataset = load_dataset(dataset_id, split="train", num_proc=8)
        labels = dataset.features["label"].names
        label2id, id2label = dict(), dict()
        for i, label in enumerate(labels):
            label2id[label] = i
            id2label[i] = label

        model = AutoModelForImageClassification.from_pretrained(model_id, label2id=label2id, id2label=id2label, ignore_mismatched_sizes=True)

        training_args = TrainingArguments(
            output_dir=OUT_DIR,
            # `use_cpu` replaces the deprecated `no_cuda`; fp16 stays disabled
            # because mixed-precision training is not supported on CPU.
            use_cpu=True,
            fp16=False,
            optim="adamw_torch",
            lr_scheduler_type="linear",
            learning_rate=0.00005,
            per_device_train_batch_size=8,
            num_train_epochs=3,
            gradient_accumulation_steps=1,
            use_ipex=True,  # requires intel_extension_for_pytorch to be installed
            eval_strategy="no",
            logging_strategy="epoch",
            remove_unused_columns=False,
            push_to_hub=False,
            save_total_limit=2,
            report_to="none",
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            train_dataset=dataset,
            eval_dataset=None,
        )
        train_results = trainer.train()
        trainer.save_model(OUT_DIR)
        trainer.log_metrics("train", train_results.metrics)
        trainer.save_metrics("train", train_results.metrics)
        trainer.save_state()
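
        # Create (or reuse) a private repo on the Hub, upload the trained model, then remove the local copy.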
        api.create_repo(repo_id=repo_id, private=True, exist_ok=True, token=hf_token)
        api.upload_folder(repo_id=repo_id, folder_path=OUT_DIR, path_in_repo="", token=hf_token)
        shutil.rmtree(OUT_DIR)

        return log_md
    except Exception as e:
        raise gr.Error(f"Error occurred: {e}")
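

# Gradio UI: text inputs for the source model, dataset, output repo, and HF token,
# plus a Train button and a Markdown log area.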
with gr.Blocks() as demo:
    with gr.Row():
        model_id = gr.Textbox(label="Source model", value=HF_MODEL, lines=1)
        dataset_id = gr.Textbox(label="Source dataset", value=HF_DATASET, lines=1)
    with gr.Row():
        repo_id = gr.Textbox(label="Output repo", value=HF_REPO, lines=1)
        hf_token = gr.Textbox(label="HF write token", value="", lines=1)
    with gr.Accordion("Advanced", open=False):
        slice = gr.Number(label="Slice dataset", info="If 0, use whole dataset", minimum=0, maximum=999999, step=1, value=0)
    train_btn = gr.Button("Train")
    log_md = gr.Markdown(label="Log", value="<br><br>")

    train_btn.click(train, [model_id, dataset_id, repo_id, hf_token, slice, log_md], [log_md])


demo.queue().launch()