# trtest1 / app.py — uploaded by John6666 (commit e869a77, verified)
import spaces
import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import os
import shutil
from huggingface_hub import HfApi, HfFolder
from transformers import AutoModelForImageClassification, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
from sklearn.metrics import accuracy_score
@spaces.GPU
def dummy_gpu():
    """No-op placeholder; the @spaces.GPU decorator registers GPU usage for this Space."""
    return None
# Default sources/targets; each may be overridden via the matching environment
# variable. `or` keeps the default when the variable is unset OR empty, exactly
# like the original truthiness checks.
HF_MODEL = os.getenv("HF_MODEL") or "google/vit-base-patch16-224"
HF_DATASET = os.getenv("HF_DATASET") or "verytuffcat/recaptcha-dataset"
HF_REPO = os.getenv("HF_REPO") or ""
HF_TOKEN = os.getenv("HF_TOKEN", "")
OUT_DIR = "./new_model"  # local scratch directory for Trainer output
def pil_to_torch(image: Image.Image):
    """Convert *image* to a (3, 224, 224) float tensor in [0, 1].

    The image is forced to RGB and bicubic-resized to the ViT input size
    before tensor conversion.
    """
    rgb = image.convert("RGB")
    resized = rgb.resize((224, 224), Image.BICUBIC)
    return torchvision.transforms.functional.to_tensor(resized)
def compute_metrics(eval_pred):
    """Compute accuracy for a ``transformers.Trainer`` evaluation step.

    Args:
        eval_pred: ``(logits, labels)`` pair as supplied by the Trainer,
            where logits has shape (batch, num_classes).

    Returns:
        dict with a single ``"accuracy"`` entry in [0, 1].
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    # NOTE: the original called accuracy_score(predictions, labels), swapping
    # sklearn's (y_true, y_pred) argument order. Accuracy is symmetric so the
    # value was unchanged, but the direct computation below avoids the
    # misleading call (and the sklearn dependency inside this function).
    return dict(accuracy=float(np.mean(predictions == labels)))
def collate_fn(batch):
    """Stack a list of {"image", "label"} examples into a Trainer batch dict."""
    images = [pil_to_torch(example["image"]) for example in batch]
    targets = [example["label"] for example in batch]
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}
def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, slice: int, log_md: str, progress=gr.Progress(track_tqdm=True)):
    """Fine-tune an image-classification model on a Hub dataset and upload it.

    Args:
        model_id: source model repo on the Hub.
        dataset_id: source dataset repo (expects an "image" column and a
            ClassLabel "label" column).
        repo_id: destination repo, created private if missing.
        hf_token: write token; falls back to the HF_TOKEN env var.
        slice: if >= 1, train on only that many examples.
        log_md: markdown log pane content, returned unchanged on success.
        progress: Gradio progress tracker (tqdm-linked).

    Returns:
        ``log_md`` unchanged (the UI writes it back to the log pane).

    Raises:
        gr.Error: on missing inputs/token or any failure during training/upload.
    """
    try:
        if not model_id or not dataset_id or not repo_id: raise gr.Error("Fill fields.")
        if not hf_token: hf_token = HF_TOKEN
        if not hf_token: raise gr.Error("Input HF token.")
        HfFolder.save_token(hf_token)
        api = HfApi(token=hf_token)
        # NOTE(review): [1:slice+1] skips the first example; presumably [:slice]
        # was intended — kept as-is to preserve existing behavior.
        if slice >= 1: dataset = load_dataset(dataset_id, split=f"train[1:{int(slice)+1}]", num_proc=8)
        else: dataset = load_dataset(dataset_id, split="train", num_proc=8)
        labels = dataset.features["label"].names
        label2id = {label: i for i, label in enumerate(labels)}
        id2label = {i: label for i, label in enumerate(labels)}
        model = AutoModelForImageClassification.from_pretrained(model_id, label2id=label2id, id2label=id2label, ignore_mismatched_sizes=True)
        training_args = TrainingArguments(
            output_dir=OUT_DIR,
            use_cpu=True,
            no_cuda=True,
            # BUG FIX: fp16 mixed precision is CUDA-only; TrainingArguments
            # raises a ValueError when fp16=True is combined with use_cpu=True,
            # so it must stay off here.
            fp16=False,
            optim="adamw_torch",
            lr_scheduler_type="linear",
            learning_rate=0.00005,
            per_device_train_batch_size=8,
            num_train_epochs=3,
            gradient_accumulation_steps=1,
            use_ipex=True,
            eval_strategy="no",
            logging_strategy="epoch",
            remove_unused_columns=False,  # keep the raw "image" column for collate_fn
            push_to_hub=False,
            save_total_limit=2,
            report_to="none"
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            train_dataset=dataset,
            eval_dataset=None,
        )
        train_results = trainer.train()
        trainer.save_model(OUT_DIR)
        trainer.log_metrics("train", train_results.metrics)
        trainer.save_metrics("train", train_results.metrics)
        trainer.save_state()
        api.create_repo(repo_id=repo_id, private=True, exist_ok=True, token=hf_token)
        # BUG FIX: the upload previously passed the global HF_TOKEN instead of
        # the token supplied (or resolved) above, so a UI-provided token was
        # ignored at upload time.
        api.upload_folder(repo_id=repo_id, folder_path=OUT_DIR, path_in_repo="", token=hf_token)
        shutil.rmtree(OUT_DIR)
        return log_md
    except gr.Error:
        # Don't re-wrap deliberate UI errors; keep their original message.
        raise
    except Exception as e:
        raise gr.Error(f"Error occurred: {e}")
# --- Gradio UI: collect inputs, wire them to train(), and launch queued. ---
with gr.Blocks() as demo:
    with gr.Row():
        model_box = gr.Textbox(label="Source model", value=HF_MODEL, lines=1)
        dataset_box = gr.Textbox(label="Source dataset", value=HF_DATASET, lines=1)
    with gr.Row():
        repo_box = gr.Textbox(label="Output repo", value=HF_REPO, lines=1)
        token_box = gr.Textbox(label="HF write token", value="", lines=1)
    with gr.Accordion("Advanced", open=False):
        slice_box = gr.Number(label="Slice dataset", info="If 0, use whole dataset", minimum=0, maximum=999999, step=1, value=0)
    run_button = gr.Button("Train")
    log_box = gr.Markdown(label="Log", value="<br><br>")
    run_button.click(train, [model_box, dataset_box, repo_box, token_box, slice_box, log_box], [log_box])
demo.queue().launch()