mrcuddle committed
Commit 33e9d70 · verified · 1 Parent(s): 963d46a

Update hf_merge.py

Files changed (1)
  1. hf_merge.py +105 -166
hf_merge.py CHANGED
@@ -1,156 +1,103 @@
- from pathlib import Path
- from time import sleep
- from tqdm import tqdm
- import argparse
- import requests
- import git
- import merge
  import os
  import shutil
- import sys
- import yaml
- from huggingface_hub import snapshot_download
- from huggingface_hub import HfApi, hf_hub_download
-
- def parse_arguments():
-     parser = argparse.ArgumentParser(description="Merge HuggingFace models")
-     parser.add_argument('repo_list', type=str, help='File containing list of repositories to merge, supports mergekit yaml or txt')
-     parser.add_argument('output_dir', type=str, help='Directory for the merged models')
-     parser.add_argument('-staging', type=str, default='./staging', help='Path to staging folder')
-     parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
-     parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor for the weight delta')
-     parser.add_argument('--dry', action='store_true', help='Run in dry mode without making any changes')
-     return parser.parse_args()
-
- def repo_list_generator(file_path, default_p, default_lambda_val):
-     _, file_extension = os.path.splitext(file_path)
-
-     # Branching based on file extension
-     if file_extension.lower() == '.yaml' or file_extension.lower() == ".yml":
-         with open(file_path, 'r') as file:
-             data = yaml.safe_load(file)
-
-         for model_info in data['models']:
-             model_name = model_info['model']
-             p = model_info.get('parameters', {}).get('weight', default_p)
-             lambda_val = 1 / model_info.get('parameters', {}).get('density', default_lambda_val)
-             yield model_name, p, lambda_val
-
-     else: # Defaulting to txt file processing
-         with open(file_path, "r") as file:
-             repos_to_process = file.readlines()
-
-         for repo in repos_to_process:
-             yield repo.strip(), default_p, default_lambda_val
-
- def reset_directories(directories, dry_run):
-     for directory in directories:
-         if os.path.exists(directory):
-             if dry_run:
-                 print(f"[DRY RUN] Would delete directory {directory}")
-             else:
-                 # Check if the directory is a symlink
-                 if os.path.islink(directory):
-                     os.unlink(directory) # Remove the symlink
-                 else:
-                     shutil.rmtree(directory, ignore_errors=False)
-                 print(f"Directory {directory} deleted successfully.")
-
- def do_merge(tensor_map, staging_path, p, lambda_val, dry_run=False):
-     if dry_run:
-         print(f"[DRY RUN] Would merge with {staging_path}")
-     else:
-         try:
-             print(f"Merge operation for {staging_path}")
-             tensor_map = merge.merge_folder(tensor_map, staging_path, p, lambda_val)
-             print("Merge operation completed successfully.")
-         except Exception as e:
-             print(f"Error during merge operation: {e}")
-     return tensor_map
-
- def download_repo(repo_name, path, dry_run=False):
-     if not os.path.exists(path):
-         os.makedirs(path)
-
-     api = HfApi()
-
-     # Get the list of all files in the repository using HfApi
-     repo_files = api.list_repo_files(repo_name)
-
-     if dry_run:
-         print(f"[DRY RUN] Would download the following files from {repo_name} to {path}:")
-         for file_path in repo_files:
-             print(file_path)
-     else:
-         print(f"Downloading the entire repository {repo_name} directly to {path}.")
-
-         for file_path in repo_files:
-             print(f"Downloading {path}/{file_path}...")
-
-             # Download each file directly to the specified path
-             hf_hub_download(
-                 repo_id=repo_name,
-                 filename=file_path,
-                 cache_dir=path,
-                 local_dir=path, # Store directly in the target directory
-                 local_dir_use_symlinks=False # Ensure symlinks are not used
-             )
-
-         print(f"Repository {repo_name} downloaded successfully to {path}.")
-
- def should_create_symlink(repo_name):
-     if os.path.exists(repo_name):
-         return True, os.path.isfile(repo_name)
-     return False, False
-
- def download_or_link_repo(repo_name, path, dry_run=False):
-     symlink, is_file = should_create_symlink(repo_name)
-
-     if symlink and is_file:
-         os.makedirs(path, exist_ok=True)
-         symlink_path = os.path.join(path, os.path.basename(repo_name))
-         os.symlink(repo_name, symlink_path)
-     elif symlink:
-         os.symlink(repo_name, path)
-     else:
-         download_repo(repo_name, path, dry_run)
-
- def delete_repo(path, dry_run=False):
-     if dry_run:
-         print(f"[DRY RUN] Would delete repository at {path}")
-     else:
          try:
-             shutil.rmtree(path)
-             print(f"Repository at {path} deleted successfully.")
          except Exception as e:
-             print(f"Error deleting repository at {path}: {e}")
  def get_max_vocab_size(repo_list):
      max_vocab_size = 0
      repo_with_max_vocab = None

-     for repo in repo_list:
-         repo_name = repo[0].strip()
-         url = f"https://huggingface.co/{repo_name}/raw/main/config.json"
-
          try:
              response = requests.get(url)
-             response.raise_for_status()
              config = response.json()
-             vocab_size = config.get("vocab_size", 0)
-
              if vocab_size > max_vocab_size:
                  max_vocab_size = vocab_size
                  repo_with_max_vocab = repo_name
-
          except requests.RequestException as e:
-             print(f"Error fetching data from {url}: {e}")

      return max_vocab_size, repo_with_max_vocab

  def download_json_files(repo_name, file_paths, output_dir):
      base_url = f"https://huggingface.co/{repo_name}/raw/main/"
-
      for file_path in file_paths:
          url = base_url + file_path
          response = requests.get(url)
@@ -158,48 +105,40 @@ def download_json_files(repo_name, file_paths, output_dir):
              with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
                  file.write(response.content)
          else:
-             print(f"Failed to download {file_path}")

- def process_repos(output_dir, base_model, staging_model, repo_list_file, p, lambda_val, dry_run=False):
-     # Check if output_dir exists
-     if os.path.exists(output_dir):
-         sys.exit(f"Output directory '{output_dir}' already exists. Exiting to prevent data loss.")

-     # Reset base and staging directories
-     reset_directories([base_model, staging_model], dry_run)

-     repo_list_gen = repo_list_generator(repo_list_file, p, lambda_val)

-     repos_to_process = list(repo_list_gen)

-     # Initial download for 'base_model'
-     download_or_link_repo(repos_to_process[0][0].strip(), base_model, dry_run)
-     tensor_map = merge.map_tensors_to_files(base_model)

-     for i, repo in enumerate(tqdm(repos_to_process[1:], desc='Merging Repos')):
-         repo_name = repo[0].strip()
-         repo_p = repo[1]
-         repo_lambda = repo[2]
-         delete_repo(staging_model, dry_run)
-         download_or_link_repo(repo_name, staging_model, dry_run)
-         tensor_map = do_merge(tensor_map, staging_model, repo_p, repo_lambda, dry_run)

-     os.makedirs(output_dir, exist_ok=True)
-     merge.copy_nontensor_files(base_model, output_dir)

-     # Handle LLMs that add tokens by taking the largest
-     if os.path.exists(os.path.join(output_dir, 'config.json')):
-         max_vocab_size, repo_name = get_max_vocab_size(repos_to_process)
-         if max_vocab_size > 0:
-             file_paths = ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']
-             download_json_files(repo_name, file_paths, output_dir)

-     reset_directories([base_model, staging_model], dry_run)
-     merge.save_tensor_map(tensor_map, output_dir)

  if __name__ == "__main__":
-     args = parse_arguments()
-     staging_path = Path(args.staging)
-     os.makedirs(args.staging, exist_ok=True)
-     process_repos(args.output_dir, staging_path / 'base_model', staging_path / 'staging_model', args.repo_list, args.p, args.lambda_val, args.dry)
-
  import os
  import shutil
+ import argparse
+ import requests
+ from tqdm import tqdm
+ from huggingface_hub import HfApi, Repository, hf_hub_download, upload_folder
+ from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
+
+ class RepositoryManager:
+     def __init__(self, repo_id=None, token=None, dry_run=False):
+         self.repo_id = repo_id
+         self.token = token
+         self.dry_run = dry_run
+         self.api = HfApi(token=token) if token else HfApi()
+
+     def download_repo(self, repo_name, path):
+         if self.dry_run:
+             print(f"[DRY RUN] Downloading {repo_name} to {path}")
+             return
+
+         if not os.path.exists(path):
+             os.makedirs(path)
+
+         repo_files = self.api.list_repo_files(repo_name)
+
+         for file_path in tqdm(repo_files, desc=f"Downloading {repo_name}"):
+             file_url = f"https://huggingface.co/{repo_name}/resolve/main/{file_path}"
+             hf_hub_download(repo_id=repo_name, filename=file_path, cache_dir=path, local_dir=path)
+
+     def delete_repo(self, path):
+         if self.dry_run:
+             print(f"[DRY RUN] Deleting {path}")
+         else:
+             shutil.rmtree(path, ignore_errors=True)
+             print(f"Deleted {path}")
+
+ class ModelMerger:
+     def __init__(self, staging_path, repo_id=None, token=None, dry_run=False):
+         self.staging_path = staging_path
+         self.repo_id = repo_id
+         self.token = token
+         self.dry_run = dry_run
+         self.tensor_map = None
+         self.api = HfApi(token=token) if token else HfApi()
+
+     def prepare_base_model(self, base_model_name, base_model_path):
+         repo_manager = RepositoryManager(self.repo_id, self.token, self.dry_run)
+         repo_manager.download_repo(base_model_name, base_model_path)
+         self.tensor_map = map_tensors_to_files(base_model_path)
+
+     def merge_repo(self, repo_name, repo_path, p, lambda_val):
+         repo_manager = RepositoryManager(self.repo_id, self.token, self.dry_run)
+         repo_manager.delete_repo(repo_path)
+         repo_manager.download_repo(repo_name, repo_path)
+
+         if self.dry_run:
+             print(f"[DRY RUN] Merging {repo_name} with p={p} and lambda={lambda_val}")
+             return
+
          try:
+             self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val)
+             print(f"Merged {repo_name}")
          except Exception as e:
+             print(f"Error merging {repo_name}: {e}")
+
+     def finalize_merge(self, output_dir):
+         os.makedirs(output_dir, exist_ok=True)
+         copy_nontensor_files(os.path.join(self.staging_path, 'base_model'), output_dir)
+         save_tensor_map(self.tensor_map, output_dir)
+
+     def upload_model(self, output_dir, repo_name, commit_message):
+         if self.dry_run:
+             print(f"[DRY RUN] Uploading model to {repo_name}")
+             return
+
+         # Create the target repo if it does not exist, then push the merged folder via the Hub API.
+         self.api.create_repo(repo_id=repo_name, exist_ok=True)
+         self.api.upload_folder(repo_id=repo_name, folder_path=output_dir, commit_message=commit_message)
+         print(f"Model uploaded to {repo_name}")

  def get_max_vocab_size(repo_list):
      max_vocab_size = 0
      repo_with_max_vocab = None
+     base_url = "https://huggingface.co/{}/raw/main/config.json"

+     for repo_name in repo_list:
+         url = base_url.format(repo_name)
          try:
              response = requests.get(url)
              config = response.json()
+             vocab_size = config.get('vocab_size', 0)
              if vocab_size > max_vocab_size:
                  max_vocab_size = vocab_size
                  repo_with_max_vocab = repo_name
          except requests.RequestException as e:
+             print(f"Error fetching vocab size from {repo_name}: {e}")

      return max_vocab_size, repo_with_max_vocab

  def download_json_files(repo_name, file_paths, output_dir):
      base_url = f"https://huggingface.co/{repo_name}/raw/main/"
      for file_path in file_paths:
          url = base_url + file_path
          response = requests.get(url)
              with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as file:
                  file.write(response.content)
          else:
+             print(f"Failed to download {file_path} from {repo_name}")

+ def main():
+     parser = argparse.ArgumentParser(description="Merge and upload HuggingFace models")
+     parser.add_argument('repos', nargs='+', help='Repositories to merge; the first is used as the base model')
+     parser.add_argument('output_dir', help='Output directory')
+     parser.add_argument('-staging', default='./staging', help='Staging folder')
+     parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
+     parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor')
+     parser.add_argument('--dry', action='store_true', help='Dry run mode')
+     parser.add_argument('--token', type=str, help='HuggingFace token')
+     parser.add_argument('--repo', type=str, help='HuggingFace repo to upload to')
+     parser.add_argument('--commit-message', type=str, default='Upload merged model', help='Commit message for model upload')
+     args = parser.parse_args()
+
+     staging_path = os.path.abspath(args.staging)
+     os.makedirs(staging_path, exist_ok=True)
+
+     base_model_path = os.path.join(staging_path, "base_model")
+     staging_model_path = os.path.join(staging_path, "staging_model")
+
+     # The first positional repo is the base model; the remaining repos are merged into it.
+     model_merger = ModelMerger(staging_path, args.repo, args.token, args.dry)
+     model_merger.prepare_base_model(args.repos[0], base_model_path)
+
+     for repo_name in tqdm(args.repos[1:], desc="Merging Repos"):
+         model_merger.merge_repo(repo_name, staging_model_path, args.p, args.lambda_val)
+
+     model_merger.finalize_merge(args.output_dir)
+
+     # Handle models that add tokens by taking config/tokenizer files from the repo with the largest vocab
+     max_vocab_size, repo_with_max_vocab = get_max_vocab_size(args.repos)
+     if max_vocab_size > 0:
+         file_paths = ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']
+         download_json_files(repo_with_max_vocab, file_paths, args.output_dir)
+
+     if args.repo:
+         model_merger.upload_model(args.output_dir, args.repo, args.commit_message)
+
  if __name__ == "__main__":
+     main()
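For reference, a minimal standalone sketch of the upload step that upload_model performs, using only documented huggingface_hub calls; the token, repo ID, output folder, and commit message below are placeholders, not values taken from this commit:

from huggingface_hub import HfApi

api = HfApi(token="hf_xxx")  # placeholder token
api.create_repo(repo_id="your-username/merged-model", exist_ok=True)  # placeholder target repo
api.upload_folder(
    repo_id="your-username/merged-model",
    folder_path="./merged-model",  # placeholder: the output_dir produced by the merge
    commit_message="Upload merged model",
)

On the command line, the same flow would be driven by something like: python hf_merge.py <base-repo> <repo-to-merge> <output-dir> --repo <target-repo> --token <hf-token>, with the repo IDs and paths supplied by the user.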