John6666 committed
Commit: e869a77
Parent: bd0de8c

Upload app.py

Files changed (1): app.py (+36, -17)
app.py CHANGED
@@ -5,9 +5,10 @@ import torchvision
 from PIL import Image
 import numpy as np
 import os
-from huggingface_hub import HfApi, HfFolder, Repository
-from transformers import ViTForImageClassification, Trainer, TrainingArguments
-from datasets import load_dataset
+import shutil
+from huggingface_hub import HfApi, HfFolder
+from transformers import AutoModelForImageClassification, Trainer, TrainingArguments
+from datasets import load_dataset, Dataset
 from sklearn.metrics import accuracy_score
 
 
@@ -19,7 +20,6 @@ def dummy_gpu():
 HF_MODEL = "google/vit-base-patch16-224"
 HF_DATASET = "verytuffcat/recaptcha-dataset"
 HF_REPO = ""
-
 HF_TOKEN = os.getenv("HF_TOKEN", "")
 if os.getenv("HF_REPO"): HF_REPO = os.getenv("HF_REPO")
 if os.getenv("HF_DATASET"): HF_DATASET = os.getenv("HF_DATASET")
@@ -27,32 +27,45 @@ if os.getenv("HF_MODEL"): HF_MODEL = os.getenv("HF_MODEL")
 OUT_DIR = "./new_model"
 
 
+def pil_to_torch(image: Image.Image):
+    return torchvision.transforms.functional.to_tensor(image.convert("RGB").resize((224, 224), Image.BICUBIC))
+
+
 def compute_metrics(eval_pred):
     predictions, labels = eval_pred
     predictions = np.argmax(predictions, axis=1)
-    return dict(accuracy=accuracy_score(predictions, labels))
+    metrics = dict(accuracy=accuracy_score(predictions, labels))
+    return metrics
 
 
 def collate_fn(batch):
-    pixel_values = torch.stack([torchvision.transforms.functional.to_tensor(x["image"].convert("RGB").resize((224, 224), Image.BICUBIC)) for x in batch])
+    pixel_values = torch.stack([pil_to_torch(x["image"]) for x in batch])
     labels = torch.tensor([x["label"] for x in batch])
     return {"pixel_values": pixel_values, "labels": labels}
 
 
-def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, log_md: str, progress=gr.Progress(track_tqdm=True)):
+def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, slice: int, log_md: str, progress=gr.Progress(track_tqdm=True)):
     try:
         if not model_id or not dataset_id or not repo_id: raise gr.Error("Fill fields.")
         if not hf_token: hf_token = HF_TOKEN
         if not hf_token: raise gr.Error("Input HF token.")
         HfFolder.save_token(hf_token)
+        api = HfApi(token=hf_token)
+
+        if slice >= 1: dataset = load_dataset(dataset_id, split=f"train[1:{int(slice)+1}]", num_proc=8)
+        else: dataset = load_dataset(dataset_id, split="train", num_proc=8)
+        labels = dataset.features["label"].names
+        label2id, id2label = dict(), dict()
+        for i, label in enumerate(labels):
+            label2id[label] = i
+            id2label[i] = label
 
-        model = ViTForImageClassification.from_pretrained(model_id)
-        dataset = load_dataset(dataset_id, split="train")
+        model = AutoModelForImageClassification.from_pretrained(model_id, label2id=label2id, id2label=id2label, ignore_mismatched_sizes=True)
 
         training_args = TrainingArguments(
             output_dir=OUT_DIR,
             use_cpu=True,
-            no_cuda=True,
+            no_cuda=True, #
             fp16=True,
             optim="adamw_torch",
             lr_scheduler_type="linear",
@@ -61,8 +74,9 @@ def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, log_md: s
             num_train_epochs=3,
             gradient_accumulation_steps=1,
             use_ipex=True,
+            #eval_strategy="epoch",
             eval_strategy="no",
-            logging_strategy="no",
+            logging_strategy="epoch",
             remove_unused_columns=False,
             push_to_hub=False,
             save_total_limit=2,
@@ -74,15 +88,18 @@ def train(model_id: str, dataset_id: str, repo_id: str, hf_token: str, log_md: s
             data_collator=collate_fn,
             compute_metrics=compute_metrics,
             train_dataset=dataset,
+            #eval_dataset=dataset,
             eval_dataset=None,
         )
-        trainer.train()
+        train_results = trainer.train()
         trainer.save_model(OUT_DIR)
+        trainer.log_metrics("train", train_results.metrics)
+        trainer.save_metrics("train", train_results.metrics)
+        trainer.save_state()
 
-        api = HfApi(token=hf_token)
-        api.create_repo(repo_id=repo_id, private=True, token=hf_token)
-        repo = Repository(local_dir=OUT_DIR, clone_from=repo_id, use_auth_token=hf_token)
-        repo.push_to_hub()
+        api.create_repo(repo_id=repo_id, private=True, exist_ok=True, token=hf_token)
+        api.upload_folder(repo_id=repo_id, folder_path=OUT_DIR, path_in_repo="", token=HF_TOKEN)
+        shutil.rmtree(OUT_DIR)
 
         return log_md
     except Exception as e:
@@ -96,10 +113,12 @@ with gr.Blocks() as demo:
     with gr.Row():
         repo_id = gr.Textbox(label="Output repo", value=HF_REPO, lines=1)
         hf_token = gr.Textbox(label="HF write token", value="", lines=1)
+    with gr.Accordion("Advanced", open=False):
+        slice = gr.Number(label="Slice dataset", info="If 0, use whole dataset", minimum=0, maximum=999999, step=1, value=0)
     train_btn = gr.Button("Train")
     log_md = gr.Markdown(label="Log", value="<br><br>")
 
-    train_btn.click(train, [model_id, dataset_id, repo_id, hf_token, log_md], [log_md])
+    train_btn.click(train, [model_id, dataset_id, repo_id, hf_token, slice, log_md], [log_md])
 
 
 demo.queue().launch()
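
Note: the main change in this commit is swapping the git-based Repository push for the HTTP-based HfApi upload flow, plus an optional dataset slice. Below is a minimal standalone sketch of the new upload flow; the repo ID, token value, and output directory are placeholders for illustration, not values taken from the Space.

import shutil
from huggingface_hub import HfApi

OUT_DIR = "./new_model"                    # folder written by trainer.save_model()
repo_id = "your-username/vit-recaptcha"    # hypothetical target repo
api = HfApi(token="hf_xxx")                # hypothetical token with write access

# exist_ok=True keeps repeated runs from failing once the repo already exists.
api.create_repo(repo_id=repo_id, private=True, exist_ok=True)

# upload_folder pushes the saved files over HTTP; no local git clone is needed,
# unlike the removed Repository(...).push_to_hub() workflow.
api.upload_folder(repo_id=repo_id, folder_path=OUT_DIR, path_in_repo="")

# Remove the local copy afterwards to reclaim disk space on the Space.
shutil.rmtree(OUT_DIR)

The new "Slice dataset" input relies on the datasets split-slicing syntax: a value of 100 loads split="train[1:101]", while 0 loads the whole "train" split.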