mrcuddle committed · Commit 8e34ad1 · verified · 1 Parent(s): 023d71e

Upload hf_merge.py

Files changed (1): hf_merge.py (+205, -0)
hf_merge.py ADDED
from pathlib import Path
from tqdm import tqdm
import argparse
import requests
import merge
import os
import shutil
import sys
import yaml
from huggingface_hub import HfApi, hf_hub_download

def parse_arguments():
    parser = argparse.ArgumentParser(description="Merge HuggingFace models")
    parser.add_argument('repo_list', type=str, help='File listing the repositories to merge; supports mergekit YAML or plain txt')
    parser.add_argument('output_dir', type=str, help='Directory for the merged model')
    parser.add_argument('-staging', type=str, default='./staging', help='Path to the staging folder')
    parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
    parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor for the weight delta')
    parser.add_argument('--dry', action='store_true', help='Run in dry mode without making any changes')
    return parser.parse_args()

def repo_list_generator(file_path, default_p, default_lambda_val):
    _, file_extension = os.path.splitext(file_path)

    # Branch based on file extension
    if file_extension.lower() in ('.yaml', '.yml'):
        with open(file_path, 'r') as file:
            data = yaml.safe_load(file)

        for model_info in data['models']:
            model_name = model_info['model']
            p = model_info.get('parameters', {}).get('weight', default_p)
            # lambda is taken as the reciprocal of the mergekit 'density' parameter
            lambda_val = 1 / model_info.get('parameters', {}).get('density', default_lambda_val)
            yield model_name, p, lambda_val

    else:  # Default to txt file processing, one repository per line
        with open(file_path, 'r') as file:
            repos_to_process = file.readlines()

        for repo in repos_to_process:
            yield repo.strip(), default_p, default_lambda_val

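# For reference, the two accepted input formats look roughly like this
# (the repository names below are hypothetical placeholders). The first
# entry becomes the base model; the rest are merged into it:
#
#   repos.txt, one repository per line, merged with the default p and lambda:
#       mistralai/Mistral-7B-v0.1
#       SomeUser/some-finetune-7b
#
#   repos.yaml, mergekit-style, where 'weight' maps to p and 1/'density' to lambda:
#       models:
#         - model: mistralai/Mistral-7B-v0.1
#         - model: SomeUser/some-finetune-7b
#           parameters:
#             weight: 0.5
#             density: 0.33
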
def reset_directories(directories, dry_run):
    for directory in directories:
        if os.path.exists(directory):
            if dry_run:
                print(f"[DRY RUN] Would delete directory {directory}")
            else:
                # Check if the directory is a symlink
                if os.path.islink(directory):
                    os.unlink(directory)  # Remove the symlink
                else:
                    shutil.rmtree(directory, ignore_errors=False)
                print(f"Directory {directory} deleted successfully.")

def do_merge(tensor_map, staging_path, p, lambda_val, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would merge with {staging_path}")
    else:
        try:
            print(f"Merge operation for {staging_path}")
            tensor_map = merge.merge_folder(tensor_map, staging_path, p, lambda_val)
            print("Merge operation completed successfully.")
        except Exception as e:
            print(f"Error during merge operation: {e}")
    return tensor_map

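# The companion `merge` module imported above is not included in this
# commit. A minimal sketch of the interface this script assumes, inferred
# from the call sites (signatures are assumptions, not the actual code):
#
#     merge.map_tensors_to_files(model_dir)                     -> tensor_map
#     merge.merge_folder(tensor_map, model_dir, p, lambda_val)  -> tensor_map
#     merge.copy_nontensor_files(src_dir, dst_dir)
#     merge.save_tensor_map(tensor_map, output_dir)
#
# Per the argument help above, p is a dropout probability applied to the
# incoming weights and lambda_val scales the resulting weight delta.
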
def download_repo(repo_name, path, dry_run=False):
    if not os.path.exists(path):
        os.makedirs(path)

    api = HfApi()

    # Get the list of all files in the repository using HfApi
    repo_files = api.list_repo_files(repo_name)

    if dry_run:
        print(f"[DRY RUN] Would download the following files from {repo_name} to {path}:")
        for file_path in repo_files:
            print(file_path)
    else:
        print(f"Downloading the entire repository {repo_name} directly to {path}.")

        for file_path in repo_files:
            print(f"Downloading {path}/{file_path}...")

            # Download each file directly to the specified path
            hf_hub_download(
                repo_id=repo_name,
                filename=file_path,
                cache_dir=path,
                local_dir=path,  # Store directly in the target directory
                local_dir_use_symlinks=False  # Ensure symlinks are not used
            )

        print(f"Repository {repo_name} downloaded successfully to {path}.")

def should_create_symlink(repo_name):
    # A repo given as an existing local path is linked rather than downloaded;
    # also report whether that path is a single file
    if os.path.exists(repo_name):
        return True, os.path.isfile(repo_name)
    return False, False

def download_or_link_repo(repo_name, path, dry_run=False):
    symlink, is_file = should_create_symlink(repo_name)

    if symlink and is_file:
        os.makedirs(path, exist_ok=True)
        symlink_path = os.path.join(path, os.path.basename(repo_name))
        os.symlink(repo_name, symlink_path)
    elif symlink:
        os.symlink(repo_name, path)
    else:
        download_repo(repo_name, path, dry_run)

def delete_repo(path, dry_run=False):
    if dry_run:
        print(f"[DRY RUN] Would delete repository at {path}")
    else:
        try:
            shutil.rmtree(path)
            print(f"Repository at {path} deleted successfully.")
        except Exception as e:
            print(f"Error deleting repository at {path}: {e}")

def get_max_vocab_size(repo_list):
    max_vocab_size = 0
    repo_with_max_vocab = None

    for repo in repo_list:
        repo_name = repo[0].strip()
        url = f"https://huggingface.co/{repo_name}/raw/main/config.json"

        try:
            response = requests.get(url)
            response.raise_for_status()
            config = response.json()
            vocab_size = config.get("vocab_size", 0)

            if vocab_size > max_vocab_size:
                max_vocab_size = vocab_size
                repo_with_max_vocab = repo_name

        except requests.RequestException as e:
            print(f"Error fetching data from {url}: {e}")

    return max_vocab_size, repo_with_max_vocab

def download_json_files(repo_name, file_paths, output_dir):
    base_url = f"https://huggingface.co/{repo_name}/raw/main/"

    for file_path in file_paths:
        url = base_url + file_path
        response = requests.get(url)
        if response.status_code == 200:
            with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
                file.write(response.content)
        else:
            print(f"Failed to download {file_path}")

def process_repos(output_dir, base_model, staging_model, repo_list_file, p, lambda_val, dry_run=False):
    # Refuse to run if output_dir already exists
    if os.path.exists(output_dir):
        sys.exit(f"Output directory '{output_dir}' already exists. Exiting to prevent data loss.")

    # Reset base and staging directories
    reset_directories([base_model, staging_model], dry_run)

    repos_to_process = list(repo_list_generator(repo_list_file, p, lambda_val))

    # Initial download for 'base_model'
    download_or_link_repo(repos_to_process[0][0].strip(), base_model, dry_run)
    tensor_map = merge.map_tensors_to_files(base_model)

    for repo_name, repo_p, repo_lambda in tqdm(repos_to_process[1:], desc='Merging Repos'):
        repo_name = repo_name.strip()
        delete_repo(staging_model, dry_run)
        download_or_link_repo(repo_name, staging_model, dry_run)
        tensor_map = do_merge(tensor_map, staging_model, repo_p, repo_lambda, dry_run)

    os.makedirs(output_dir, exist_ok=True)
    merge.copy_nontensor_files(base_model, output_dir)

    # Handle LLMs that add tokens by taking the tokenizer files from the repo with the largest vocabulary
    if os.path.exists(os.path.join(output_dir, 'config.json')):
        max_vocab_size, repo_name = get_max_vocab_size(repos_to_process)
        if max_vocab_size > 0:
            file_paths = ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']
            download_json_files(repo_name, file_paths, output_dir)

    reset_directories([base_model, staging_model], dry_run)
    merge.save_tensor_map(tensor_map, output_dir)

if __name__ == "__main__":
    args = parse_arguments()
    staging_path = Path(args.staging)
    os.makedirs(args.staging, exist_ok=True)
    process_repos(args.output_dir, staging_path / 'base_model', staging_path / 'staging_model',
                  args.repo_list, args.p, args.lambda_val, args.dry)

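# Example invocations (hypothetical paths). For txt lists, -p and -lambda
# apply to every repository; for YAML lists they are only fallbacks, since
# entries can carry their own 'weight' and 'density' parameters:
#
#     python hf_merge.py repos.txt ./merged -p 0.5 -lambda 3.0
#     python hf_merge.py repos.yaml ./merged --dry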