Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Add script for fixing params number
Browse files
backend/app/services/models.py
CHANGED
@@ -454,11 +454,11 @@ class ModelService(HuggingFaceService):
|
|
454 |
if model_size is None:
|
455 |
logger.error(LogFormatter.error("Model size validation failed", error))
|
456 |
raise Exception(error)
|
457 |
-
logger.info(LogFormatter.success(f"Model size validation passed: {model_size:.1f}
|
458 |
|
459 |
# Size limits based on precision
|
460 |
if model_data["precision"] in ["float16", "bfloat16"] and model_size > 100:
|
461 |
-
error_msg = f"Model too large for {model_data['precision']} (limit:
|
462 |
logger.error(LogFormatter.error("Size limit exceeded", error_msg))
|
463 |
raise Exception(error_msg)
|
464 |
|
|
|
454 |
if model_size is None:
|
455 |
logger.error(LogFormatter.error("Model size validation failed", error))
|
456 |
raise Exception(error)
|
457 |
+
logger.info(LogFormatter.success(f"Model size validation passed: {model_size:.1f}B"))
|
458 |
|
459 |
# Size limits based on precision
|
460 |
if model_data["precision"] in ["float16", "bfloat16"] and model_size > 100:
|
461 |
+
error_msg = f"Model too large for {model_data['precision']} (limit: 100B)"
|
462 |
logger.error(LogFormatter.error("Size limit exceeded", error_msg))
|
463 |
raise Exception(error_msg)
|
464 |
|
backend/utils/fix_wrong_model_size.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import pytz
|
4 |
+
import logging
|
5 |
+
import asyncio
|
6 |
+
from datetime import datetime
|
7 |
+
from pathlib import Path
|
8 |
+
import huggingface_hub
|
9 |
+
from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
from git import Repo
|
12 |
+
from datetime import datetime
|
13 |
+
from tqdm.auto import tqdm
|
14 |
+
from tqdm.contrib.logging import logging_redirect_tqdm
|
15 |
+
|
16 |
+
from app.config.hf_config import HF_TOKEN, QUEUE_REPO, API, EVAL_REQUESTS_PATH
|
17 |
+
|
18 |
+
from app.utils.model_validation import ModelValidator
|
19 |
+
|
20 |
+
# Silence huggingface_hub's own logging and progress bars so the script's
# tqdm output stays readable.
huggingface_hub.logging.set_verbosity_error()
huggingface_hub.utils.disable_progress_bars()

# Only surface errors; normal progress is reported through tqdm instead.
logging.basicConfig(
    level=logging.ERROR,
    format='%(message)s'
)
logger = logging.getLogger(__name__)
load_dotenv()  # load HF_TOKEN and friends from a local .env, if present

# Shared validator used to recompute model parameter counts.
validator = ModelValidator()
31 |
+
|
32 |
+
def get_changed_files(repo_path, start_date, end_date):
    """Return the set of file paths changed in ``repo_path`` between two dates.

    Args:
        repo_path: Path to a local git checkout.
        start_date: Inclusive window start, ``YYYY-MM-DD``.
        end_date: Inclusive window end, ``YYYY-MM-DD``.

    Returns:
        set[str]: repo-relative paths touched by commits inside the window.
    """
    repo = Repo(repo_path)
    start = datetime.strptime(start_date, '%Y-%m-%d')
    end = datetime.strptime(end_date, '%Y-%m-%d')

    changed_files = set()
    pbar = tqdm(repo.iter_commits(), desc=f"Reading commits from {end_date} to {start_date}")
    for commit in pbar:
        commit_date = datetime.fromtimestamp(commit.committed_date)
        pbar.set_postfix_str(f"Commit date: {commit_date}")
        if start <= commit_date <= end:
            # Fix: a root commit has no parents — the original indexed
            # commit.parents[0] unconditionally and raised IndexError there.
            if commit.parents:
                changed_files.update(item.a_path for item in commit.diff(commit.parents[0]))

        # iter_commits walks newest-first, so once we pass the window start
        # no older commit can be inside it.
        if commit_date < start:
            break

    return changed_files
49 |
+
|
50 |
+
|
51 |
+
def read_json(repo_path, file):
    """Load and return the parsed JSON content of ``file`` under ``repo_path``."""
    path = f"{repo_path}/{file}"
    with open(path) as handle:
        return json.load(handle)
54 |
+
|
55 |
+
|
56 |
+
def write_json(repo_path, file, content):
    """Serialize ``content`` as indented JSON into ``file`` under ``repo_path``."""
    target = f"{repo_path}/{file}"
    with open(target, "w") as handle:
        json.dump(content, handle, indent=2)
59 |
+
|
60 |
+
|
61 |
+
def main():
    """Recompute the ``params`` field of request files changed in a date window.

    Walks the local requests checkout for files touched between the hard-coded
    dates, re-queries the Hub for each model, recomputes its size via the
    shared ``validator``, and rewrites the request JSON when the stored size
    is wrong. Files or models that can no longer be resolved are skipped with
    a tqdm-safe message.
    """
    # NOTE(review): hard-coded local path and date window — adjust before running.
    requests_path = "/Users/lozowski/Developer/requests"
    start_date = "2024-12-09"
    end_date = "2025-01-07"

    changed_files = get_changed_files(requests_path, start_date, end_date)

    for file in tqdm(changed_files):
        try:
            request_data = read_json(requests_path, file)
        except FileNotFoundError:
            # Changed in the window but deleted/renamed since — nothing to fix.
            tqdm.write(f"File {file} not found")
            continue

        try:
            model_info = API.model_info(
                repo_id=request_data["model"],
                revision=request_data["revision"],
                token=HF_TOKEN
            )
        except (RepositoryNotFoundError, RevisionNotFoundError):
            # Fix: the original nested "..." inside a "..."-delimited f-string
            # (f"...{request_data["model"]}..."), a SyntaxError on Python < 3.12.
            tqdm.write(f"Model info for {request_data['model']} not found")
            continue

        # Redirect logging through tqdm so validator log lines don't corrupt the bar.
        with logging_redirect_tqdm():
            new_model_size, error = asyncio.run(validator.get_model_size(
                model_info=model_info,
                precision=request_data["precision"],
                base_model=request_data["base_model"],
                revision=request_data["revision"]
            ))

        if error:
            tqdm.write(f"Error getting model size info for {request_data['model']}, {error}")
            continue

        old_model_size = request_data["params"]
        if old_model_size != new_model_size:
            if new_model_size > 100:
                tqdm.write(f"Model: {request_data['model']}, size is more 100B: {new_model_size}")

            tqdm.write(f"Model: {request_data['model']}, old size: {request_data['params']} new size: {new_model_size}")
            tqdm.write(f"Updating request file {file}")

            request_data["params"] = new_model_size
            write_json(requests_path, file, content=request_data)


if __name__ == "__main__":
    main()