Succesfully uploaded model to HF hub in correct place
Browse files- src/config.py +5 -0
- src/trainer.py +29 -10
src/config.py
CHANGED
@@ -8,12 +8,17 @@ MAX_DOWNLOAD_TIME = 0.2
|
|
8 |
IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
|
9 |
WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
|
10 |
MODEL_PATH = pathlib.Path("/tmp/models")
|
|
|
|
|
11 |
|
12 |
IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
|
13 |
WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
|
14 |
MODEL_PATH.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
15 |
|
16 |
MODEL_NAME = "tiny_clip"
|
|
|
17 |
|
18 |
WANDB_ENTITY = "sachinruk"
|
19 |
|
|
|
8 |
IMAGE_DOWNLOAD_PATH = pathlib.Path("./data/images")
|
9 |
WANDB_LOG_PATH = pathlib.Path("/tmp/wandb_logs")
|
10 |
MODEL_PATH = pathlib.Path("/tmp/models")
|
11 |
+
VISION_MODEL_PATH = MODEL_PATH / "vision"
|
12 |
+
TEXT_MODEL_PATH = MODEL_PATH / "text"
|
13 |
|
14 |
IMAGE_DOWNLOAD_PATH.mkdir(parents=True, exist_ok=True)
|
15 |
WANDB_LOG_PATH.mkdir(parents=True, exist_ok=True)
|
16 |
MODEL_PATH.mkdir(parents=True, exist_ok=True)
|
17 |
+
VISION_MODEL_PATH.mkdir(parents=True, exist_ok=True)
|
18 |
+
TEXT_MODEL_PATH.mkdir(parents=True, exist_ok=True)
|
19 |
|
20 |
MODEL_NAME = "tiny_clip"
|
21 |
+
REPO_ID = "sachin/clip-model"
|
22 |
|
23 |
WANDB_ENTITY = "sachinruk"
|
24 |
|
src/trainer.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1 |
import os
|
2 |
|
|
|
|
|
|
|
3 |
from src import config
|
4 |
from src import data
|
5 |
from src import loss
|
@@ -11,23 +14,39 @@ from src.lightning_module import LightningModule
|
|
11 |
|
12 |
|
13 |
def _upload_model_to_hub(
|
14 |
-
vision_encoder: models.TinyCLIPVisionEncoder,
|
|
|
|
|
15 |
):
|
16 |
vision_encoder.save_pretrained(
|
17 |
-
str(config.
|
18 |
-
variant="vision_encoder",
|
19 |
safe_serialization=True,
|
20 |
-
push_to_hub=True,
|
21 |
-
repo_id="debug-clip-model",
|
22 |
)
|
23 |
text_encoder.save_pretrained(
|
24 |
-
str(config.
|
25 |
-
variant="text_encoder",
|
26 |
safe_serialization=True,
|
27 |
-
push_to_hub=True,
|
28 |
-
repo_id="debug-clip-model",
|
29 |
)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def train(trainer_config: config.TrainerConfig):
|
33 |
if "HF_TOKEN" not in os.environ:
|
@@ -51,7 +70,7 @@ def train(trainer_config: config.TrainerConfig):
|
|
51 |
trainer = utils.get_trainer(trainer_config)
|
52 |
trainer.fit(lightning_module, train_dl, valid_dl)
|
53 |
|
54 |
-
_upload_model_to_hub(vision_encoder, text_encoder)
|
55 |
|
56 |
|
57 |
if __name__ == "__main__":
|
|
|
1 |
import os
|
2 |
|
3 |
+
from huggingface_hub import HfApi
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
from src import config
|
7 |
from src import data
|
8 |
from src import loss
|
|
|
14 |
|
15 |
|
16 |
def _upload_model_to_hub(
|
17 |
+
vision_encoder: models.TinyCLIPVisionEncoder,
|
18 |
+
text_encoder: models.TinyCLIPTextEncoder,
|
19 |
+
debug: bool = False,
|
20 |
):
|
21 |
vision_encoder.save_pretrained(
|
22 |
+
str(config.VISION_MODEL_PATH),
|
|
|
23 |
safe_serialization=True,
|
|
|
|
|
24 |
)
|
25 |
text_encoder.save_pretrained(
|
26 |
+
str(config.TEXT_MODEL_PATH),
|
|
|
27 |
safe_serialization=True,
|
|
|
|
|
28 |
)
|
29 |
|
30 |
+
api = HfApi()
|
31 |
+
if debug:
|
32 |
+
repo_components = config.REPO_ID.split("/", maxsplit=1)
|
33 |
+
repo_components[1] = f"debug-{repo_components[1]}"
|
34 |
+
repo_id = "/".join(repo_components)
|
35 |
+
else:
|
36 |
+
repo_id = config.REPO_ID
|
37 |
+
common_hf_api_params = {
|
38 |
+
"repo_id": repo_id,
|
39 |
+
"repo_type": "model",
|
40 |
+
}
|
41 |
+
if not api.repo_exists(**common_hf_api_params):
|
42 |
+
logger.info(f"Creating repo {repo_id} on Hugging Face Hub.")
|
43 |
+
api.create_repo(**common_hf_api_params) # type: ignore
|
44 |
+
logger.info(f"Uploading models in {str(config.MODEL_PATH)} to {repo_id}.")
|
45 |
+
api.upload_folder(
|
46 |
+
folder_path=config.MODEL_PATH,
|
47 |
+
**common_hf_api_params, # type: ignore
|
48 |
+
) # type: ignore
|
49 |
+
|
50 |
|
51 |
def train(trainer_config: config.TrainerConfig):
|
52 |
if "HF_TOKEN" not in os.environ:
|
|
|
70 |
trainer = utils.get_trainer(trainer_config)
|
71 |
trainer.fit(lightning_module, train_dl, valid_dl)
|
72 |
|
73 |
+
_upload_model_to_hub(vision_encoder, text_encoder, trainer_config.debug)
|
74 |
|
75 |
|
76 |
if __name__ == "__main__":
|