Upload app.py

app.py CHANGED
@@ -5,9 +5,10 @@ import torchvision
 from PIL import Image
 import numpy as np
 import os
-
-from
-from
+import shutil
+from huggingface_hub import HfApi, HfFolder
+from transformers import AutoModelForImageClassification, Trainer, TrainingArguments
+from datasets import load_dataset, Dataset
 from sklearn.metrics import accuracy_score
 
 
@@ -19,7 +20,6 @@ def dummy_gpu():
 HF_MODEL = "google/vit-base-patch16-224"
 HF_DATASET = "verytuffcat/recaptcha-dataset"
 HF_REPO = ""
-
 HF_TOKEN = os.getenv("HF_TOKEN", "")
 if os.getenv("HF_REPO"): HF_REPO = os.getenv("HF_REPO")
 if os.getenv("HF_DATASET"): HF_DATASET = os.getenv("HF_DATASET")
@@ -27,32 +27,45 @@ if os.getenv("HF_MODEL"): HF_MODEL = os.getenv("HF_MODEL")
 OUT_DIR = "./new_model"
 
 
+def pil_to_torch(image: Image.Image):
+    return torchvision.transforms.functional.to_tensor(image.convert("RGB").resize((224, 224), Image.BICUBIC))
+
+
 def compute_metrics(eval_pred):
     predictions, labels = eval_pred
     predictions = np.argmax(predictions, axis=1)
-
+    metrics = dict(accuracy=accuracy_score(predictions, labels))
+    return metrics
 
 
 def collate_fn(batch):
-    pixel_values = torch.stack([
+    pixel_values = torch.stack([pil_to_torch(x["image"]) for x in batch])
     labels = torch.tensor([x["label"] for x in batch])
     return {"pixel_values": pixel_values, "labels": labels}
 
 
-def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, log_md: str, progress=gr.Progress(track_tqdm=True)):
+def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, slice: int, log_md: str, progress=gr.Progress(track_tqdm=True)):
     try:
         if not model_id or not dataset_id or not repo_id: raise gr.Error("Fill fields.")
         if not hf_token: hf_token = HF_TOKEN
         if not hf_token: raise gr.Error("Input HF token.")
         HfFolder.save_token(hf_token)
+        api = HfApi(token=hf_token)
+
+        if slice >= 1: dataset = load_dataset(dataset_id, split=f"train[1:{int(slice)+1}]", num_proc=8)
+        else: dataset = load_dataset(dataset_id, split="train", num_proc=8)
+        labels = dataset.features["label"].names
+        label2id, id2label = dict(), dict()
+        for i, label in enumerate(labels):
+            label2id[label] = i
+            id2label[i] = label
 
-        model =
-        dataset = load_dataset(dataset_id, split="train")
+        model = AutoModelForImageClassification.from_pretrained(model_id, label2id=label2id, id2label=id2label, ignore_mismatched_sizes=True)
 
         training_args = TrainingArguments(
             output_dir=OUT_DIR,
             use_cpu=True,
-            no_cuda=True,
+            no_cuda=True, #
             fp16=True,
             optim="adamw_torch",
             lr_scheduler_type="linear",
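The hunk above introduces pil_to_torch and completes compute_metrics and collate_fn. As a rough shape check, here is a minimal, self-contained sketch (not part of the commit; the dummy images and labels are invented for illustration) that runs the same conversion and collation path without downloading the dataset:

# Sketch only: exercise the commit's preprocessing/collation on synthetic rows.
import torch
import torchvision
from PIL import Image

def pil_to_torch(image: Image.Image):
    # Same resize-then-to_tensor conversion the commit adds.
    return torchvision.transforms.functional.to_tensor(
        image.convert("RGB").resize((224, 224), Image.BICUBIC)
    )

def collate_fn(batch):
    pixel_values = torch.stack([pil_to_torch(x["image"]) for x in batch])
    labels = torch.tensor([x["label"] for x in batch])
    return {"pixel_values": pixel_values, "labels": labels}

# Two dummy RGB images standing in for dataset rows (hypothetical data).
fake_rows = [
    {"image": Image.new("RGB", (300, 200), color=(255, 0, 0)), "label": 0},
    {"image": Image.new("RGB", (300, 200), color=(0, 255, 0)), "label": 1},
]
batch = collate_fn(fake_rows)
print(batch["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
print(batch["labels"])              # tensor([0, 1])

Each collated batch mirrors what Trainer passes to the model: a (N, 3, 224, 224) pixel_values tensor plus a labels tensor.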
@@ -61,8 +74,9 @@ def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, log_md: s
             num_train_epochs=3,
             gradient_accumulation_steps=1,
             use_ipex=True,
+            #eval_strategy="epoch",
             eval_strategy="no",
-            logging_strategy="
+            logging_strategy="epoch",
             remove_unused_columns=False,
             push_to_hub=False,
             save_total_limit=2,
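This hunk keeps fp16=True and use_ipex=True alongside use_cpu=True/no_cuda=True (no_cuda is the older name for the same switch). fp16 mixed precision generally assumes a CUDA device and use_ipex needs intel_extension_for_pytorch installed, so a plain CPU run may want both off. The sketch below is only an assumed CPU-only variant, not the configuration the app requires:

# Sketch only: a CPU-oriented TrainingArguments along the lines of the commit,
# with the GPU/IPEX-specific switches left out (an assumption, not the app's setup).
from transformers import TrainingArguments

cpu_args = TrainingArguments(
    output_dir="./new_model",      # same OUT_DIR as the app
    use_cpu=True,                  # newer alias for the deprecated no_cuda flag
    optim="adamw_torch",
    lr_scheduler_type="linear",
    num_train_epochs=3,
    gradient_accumulation_steps=1,
    eval_strategy="no",
    logging_strategy="epoch",
    remove_unused_columns=False,   # keep the raw "image" column for collate_fn
    push_to_hub=False,
    save_total_limit=2,
)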
@@ -74,15 +88,18 @@ def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, log_md: s
             data_collator=collate_fn,
             compute_metrics=compute_metrics,
             train_dataset=dataset,
+            #eval_dataset=dataset,
             eval_dataset=None,
         )
-        trainer.train()
+        train_results = trainer.train()
         trainer.save_model(OUT_DIR)
+        trainer.log_metrics("train", train_results.metrics)
+        trainer.save_metrics("train", train_results.metrics)
+        trainer.save_state()
 
-        api =
-        api.
-
-        repo.push_to_hub()
+        api.create_repo(repo_id=repo_id, private=True, exist_ok=True, token=hf_token)
+        api.upload_folder(repo_id=repo_id, folder_path=OUT_DIR, path_in_repo="", token=HF_TOKEN)
+        shutil.rmtree(OUT_DIR)
 
         return log_md
     except Exception as e:
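After training, the new code creates the target repo and uploads the saved model folder with HfApi. Note that upload_folder is passed token=HF_TOKEN (the environment default) rather than the user-supplied hf_token argument, which may not be intended. A minimal sketch of the same push, with placeholder repo name, path, and token (assumptions, not values from the app):

# Sketch only: create_repo + upload_folder as used in the commit,
# relying on the token given to HfApi. All names below are placeholders.
from huggingface_hub import HfApi

api = HfApi(token="hf_xxx")  # a write token; the app reads it from the textbox or HF_TOKEN
api.create_repo(repo_id="your-user/your-repo", private=True, exist_ok=True)
api.upload_folder(
    repo_id="your-user/your-repo",
    folder_path="./new_model",   # directory written by trainer.save_model(OUT_DIR)
    path_in_repo="",             # upload into the repo root
)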
@@ -96,10 +113,12 @@ with gr.Blocks() as demo:
     with gr.Row():
         repo_id = gr.Textbox(label="Output repo", value=HF_REPO, lines=1)
         hf_token = gr.Textbox(label="HF write token", value="", lines=1)
+    with gr.Accordion("Advanced", open=False):
+        slice = gr.Number(label="Slice dataset", info="If 0, use whole dataset", minimum=0, maximum=999999, step=1, value=0)
     train_btn = gr.Button("Train")
     log_md = gr.Markdown(label="Log", value="<br><br>")
 
-    train_btn.click(train, [model_id, dataset_id, repo_id, hf_token, log_md], [log_md])
+    train_btn.click(train, [model_id, dataset_id, repo_id, hf_token, slice, log_md], [log_md])
 
 
 demo.queue().launch()
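The UI hunk adds an Accordion holding a gr.Number bound to the name slice, which shadows Python's built-in slice. A stripped-down sketch of the same wiring pattern, with a hypothetical handler and component names chosen to avoid the shadowing:

# Sketch only: Accordion-wrapped Number feeding a click handler, as in the commit.
import gradio as gr

def run(limit: int) -> str:
    # Placeholder handler; the real app calls train() with many more inputs.
    return f"Would train on {int(limit)} examples" if limit else "Would train on the whole dataset"

with gr.Blocks() as demo:
    with gr.Accordion("Advanced", open=False):
        limit = gr.Number(label="Slice dataset", value=0, minimum=0, step=1)
    btn = gr.Button("Train")
    out = gr.Markdown()
    btn.click(run, [limit], [out])

demo.queue().launch()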