Upload hf_merge.py
hf_merge.py +205 -0
hf_merge.py
ADDED
@@ -0,0 +1,205 @@
from pathlib import Path
from time import sleep
from tqdm import tqdm
import argparse
import requests
import git
import merge
import os
import shutil
import sys
import yaml
from huggingface_hub import snapshot_download
from huggingface_hub import HfApi, hf_hub_download

def parse_arguments():
    parser = argparse.ArgumentParser(description="Merge HuggingFace models")
    parser.add_argument('repo_list', type=str, help='File containing list of repositories to merge, supports mergekit yaml or txt')
    parser.add_argument('output_dir', type=str, help='Directory for the merged models')
    parser.add_argument('-staging', type=str, default='./staging', help='Path to staging folder')
    parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
    parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor for the weight delta')
    parser.add_argument('--dry', action='store_true', help='Run in dry mode without making any changes')
    return parser.parse_args()

def repo_list_generator(file_path, default_p, default_lambda_val):
    _, file_extension = os.path.splitext(file_path)

    # Branching based on file extension
    if file_extension.lower() == '.yaml' or file_extension.lower() == ".yml":
        with open(file_path, 'r') as file:
            data = yaml.safe_load(file)

        for model_info in data['models']:
            model_name = model_info['model']
            p = model_info.get('parameters', {}).get('weight', default_p)
            lambda_val = 1 / model_info.get('parameters', {}).get('density', default_lambda_val)
            yield model_name, p, lambda_val

    else:  # Defaulting to txt file processing
        with open(file_path, "r") as file:
            repos_to_process = file.readlines()

        for repo in repos_to_process:
            yield repo.strip(), default_p, default_lambda_val

def reset_directories(directories, dry_run):
    for directory in directories:
        if os.path.exists(directory):
            if dry_run:
                print(f"[DRY RUN] Would delete directory {directory}")
            else:
                # Check if the directory is a symlink
                if os.path.islink(directory):
                    os.unlink(directory)  # Remove the symlink
                else:
                    shutil.rmtree(directory, ignore_errors=False)
                print(f"Directory {directory} deleted successfully.")

def do_merge(tensor_map, staging_path, p, lambda_val, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_folder(tensor_map, staging_path, p, lambda_val)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map

def download_repo(repo_name, path, dry_run=False):
    if not os.path.exists(path):
        os.makedirs(path)

    api = HfApi()

    # Get the list of all files in the repository using HfApi
    repo_files = api.list_repo_files(repo_name)

    if dry_run:
        print(f"[DRY RUN] Would download the following files from {repo_name} to {path}:")
        for file_path in repo_files:
            print(file_path)
    else:
        print(f"Downloading the entire repository {repo_name} directly to {path}.")

        for file_path in repo_files:
            print(f"Downloading {path}/{file_path}...")

            # Download each file directly to the specified path
            hf_hub_download(
                repo_id=repo_name,
                filename=file_path,
                cache_dir=path,
                local_dir=path,  # Store directly in the target directory
                local_dir_use_symlinks=False  # Ensure symlinks are not used
            )

        print(f"Repository {repo_name} downloaded successfully to {path}.")

def should_create_symlink(repo_name):
    if os.path.exists(repo_name):
        return True, os.path.isfile(repo_name)
    return False, False

def download_or_link_repo(repo_name, path, dry_run=False):
    symlink, is_file = should_create_symlink(repo_name)

    if symlink and is_file:
        os.makedirs(path, exist_ok=True)
        symlink_path = os.path.join(path, os.path.basename(repo_name))
        os.symlink(repo_name, symlink_path)
    elif symlink:
        os.symlink(repo_name, path)
    else:
        download_repo(repo_name, path, dry_run)

def delete_repo(path, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would delete repository at {path}")
    else:
        try:
            shutil.rmtree(path)
            print(f"Repository at {path} deleted successfully.")
        except Exception as e:
            print(f"Error deleting repository at {path}: {e}")

def get_max_vocab_size(repo_list):
    max_vocab_size = 0
    repo_with_max_vocab = None

    for repo in repo_list:
        repo_name = repo[0].strip()
        url = f"https://huggingface.co/{repo_name}/raw/main/config.json"

        try:
            response = requests.get(url)
            response.raise_for_status()
            config = response.json()
            vocab_size = config.get("vocab_size", 0)

            if vocab_size > max_vocab_size:
                max_vocab_size = vocab_size
                repo_with_max_vocab = repo_name

        except requests.RequestException as e:
            print(f"Error fetching data from {url}: {e}")

    return max_vocab_size, repo_with_max_vocab

def download_json_files(repo_name, file_paths, output_dir):
    base_url = f"https://huggingface.co/{repo_name}/raw/main/"

    for file_path in file_paths:
        url = base_url + file_path
        response = requests.get(url)
        if response.status_code == 200:
            with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
                file.write(response.content)
        else:
            print(f"Failed to download {file_path}")

def process_repos(output_dir, base_model, staging_model, repo_list_file, p, lambda_val, dry_run=False):
    # Check if output_dir exists
    if os.path.exists(output_dir):
        sys.exit(f"Output directory '{output_dir}' already exists. Exiting to prevent data loss.")

    # Reset base and staging directories
    reset_directories([base_model, staging_model], dry_run)

    repo_list_gen = repo_list_generator(repo_list_file, p, lambda_val)

    repos_to_process = list(repo_list_gen)

    # Initial download for 'base_model'
    download_or_link_repo(repos_to_process[0][0].strip(), base_model, dry_run)
    tensor_map = merge.map_tensors_to_files(base_model)

    for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Repos')):
        repo_name = repo[0].strip()
        repo_p = repo[1]
        repo_lambda = repo[2]
        delete_repo(staging_model, dry_run)
        download_or_link_repo(repo_name, staging_model, dry_run)
        tensor_map = do_merge(tensor_map, staging_model, repo_p, repo_lambda, dry_run)

    os.makedirs(output_dir, exist_ok=True)
    merge.copy_nontensor_files(base_model, output_dir)

    # Handle LLMs that add tokens by taking the largest
    if os.path.exists(os.path.join(output_dir, 'config.json')):
        max_vocab_size, repo_name = get_max_vocab_size(repos_to_process)
        if max_vocab_size > 0:
            file_paths = ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']
            download_json_files(repo_name, file_paths, output_dir)

    reset_directories([base_model, staging_model], dry_run)
    merge.save_tensor_map(tensor_map, output_dir)

if __name__ == "__main__":
    args = parse_arguments()
    staging_path = Path(args.staging)
    os.makedirs(args.staging, exist_ok=True)
    process_repos(args.output_dir, staging_path / 'base_model', staging_path / 'staging_model', args.repo_list, args.p, args.lambda_val, args.dry)
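
For reference, a minimal, hypothetical sketch of the YAML repo list that repo_list_generator accepts, inferred from the parsing code above. The model names, weight, and density values are placeholders, and the snippet assumes hf_merge.py and its dependencies (merge, GitPython, huggingface_hub) are importable from the working directory.

# Hypothetical demo of the YAML branch of repo_list_generator.
# Model names and parameter values are placeholders, not real repositories.
import yaml
from hf_merge import repo_list_generator

config = {
    "models": [
        {"model": "org/base-model"},  # no parameters block -> defaults apply
        {"model": "org/other-model",
         "parameters": {"weight": 0.3, "density": 2.0}},
    ]
}
with open("repos.yaml", "w") as f:
    yaml.safe_dump(config, f)

# Defaults: p = 0.5, lambda = 1 / density with density defaulting to 1.0
for model, p, lambda_val in repo_list_generator("repos.yaml", 0.5, 1.0):
    print(model, p, lambda_val)
# org/base-model 0.5 1.0
# org/other-model 0.3 0.5

From the command line, an equivalent run would look something like: python hf_merge.py repos.yaml ./merged-model -p 0.5 -lambda 1.0, where the first entry in the list becomes the base model and each later entry is downloaded into the staging folder and merged into the accumulated tensor map.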