# NOTE: the lines below were web-scrape residue from the Hugging Face Spaces page
# (status banner, file size, commit hashes, and a line-number gutter). They are not
# valid Python and have been commented out so the module can be imported.
# Spaces: Running / File size: 5,942 Bytes / commits 8e34ad1, 33e9d70
import os
import shutil
import argparse
import requests
from tqdm import tqdm
from huggingface_hub import HfApi, Repository, hf_hub_download, upload_folder
from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
class RepositoryManager:
    """Download and delete local copies of Hugging Face repositories.

    Thin wrapper around ``huggingface_hub`` that honours a ``dry_run`` flag so
    callers can preview destructive operations without touching disk/network.
    """

    def __init__(self, repo_id=None, token=None, dry_run=False):
        """
        Args:
            repo_id: Optional default repo identifier (kept for callers; not
                used by the download/delete helpers themselves).
            token: Hugging Face access token; anonymous client when omitted.
            dry_run: When True, operations only print what they would do.
        """
        self.repo_id = repo_id
        self.token = token
        self.dry_run = dry_run
        self.api = HfApi(token=token) if token else HfApi()

    def download_repo(self, repo_name, path):
        """Download every file of ``repo_name`` (main revision) into ``path``."""
        if self.dry_run:
            print(f"[DRY RUN] Downloading {repo_name} to {path}")
            return
        if not os.path.exists(path):
            os.makedirs(path)
        repo_files = self.api.list_repo_files(repo_name)
        for file_path in tqdm(repo_files, desc=f"Downloading {repo_name}"):
            # BUG FIX: removed the unused `file_url` local, and dropped the
            # redundant cache_dir=path which polluted the target directory with
            # hub cache structure — local_dir alone places files under `path`.
            hf_hub_download(repo_id=repo_name, filename=file_path, local_dir=path)

    def delete_repo(self, path):
        """Remove the local directory ``path`` (best effort, errors ignored)."""
        if self.dry_run:
            print(f"[DRY RUN] Deleting {path}")
        else:
            shutil.rmtree(path, ignore_errors=True)
            print(f"Deleted {path}")
class ModelMerger:
    """Merge tensor weights from several repos into a base model and upload it."""

    def __init__(self, staging_path, repo_id=None, token=None, dry_run=False):
        """
        Args:
            staging_path: Local directory (str) holding intermediate downloads.
            repo_id: Optional Hugging Face repo id used for uploads.
            token: Hugging Face access token; anonymous client when omitted.
            dry_run: When True, operations only print what they would do.
        """
        self.staging_path = staging_path
        self.repo_id = repo_id
        self.token = token
        self.dry_run = dry_run
        self.tensor_map = None  # populated by prepare_base_model()
        self.api = HfApi(token=token) if token else HfApi()

    def prepare_base_model(self, base_model_name, base_model_path):
        """Download the base model and build the initial tensor→file map."""
        repo_manager = RepositoryManager(self.repo_id, self.token, self.dry_run)
        repo_manager.download_repo(base_model_name, base_model_path)
        self.tensor_map = map_tensors_to_files(base_model_path)

    def merge_repo(self, repo_name, repo_path, p, lambda_val):
        """Download ``repo_name`` into ``repo_path`` and fold it into the tensor map."""
        repo_manager = RepositoryManager(self.repo_id, self.token, self.dry_run)
        repo_manager.delete_repo(repo_path)  # start from a clean staging dir
        repo_manager.download_repo(repo_name, repo_path)
        if self.dry_run:
            print(f"[DRY RUN] Merging {repo_name} with p={p} and lambda={lambda_val}")
            return
        try:
            self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val)
            print(f"Merged {repo_name}")
        except Exception as e:
            # Best-effort: a failed merge skips this repo rather than aborting the run.
            print(f"Error merging {repo_name}: {e}")

    def finalize_merge(self, output_dir):
        """Copy non-tensor files from the base model and write the merged tensors."""
        # BUG FIX: staging_path is a plain str (os.path.abspath in main), so the
        # original `self.staging_path / 'base_model'` raised TypeError.
        copy_nontensor_files(os.path.join(self.staging_path, 'base_model'), output_dir)
        save_tensor_map(self.tensor_map, output_dir)

    def upload_model(self, output_dir, repo_name, commit_message):
        """Upload ``output_dir`` to the Hub repo ``repo_name``."""
        if self.dry_run:
            print(f"[DRY RUN] Uploading model to {repo_name}")
            return
        # BUG FIX: Repository() takes a local_dir (not repo_id) and has no
        # upload_folder method; HfApi.upload_folder is the supported upload path.
        self.api.create_repo(repo_id=repo_name, exist_ok=True)
        self.api.upload_folder(folder_path=output_dir, repo_id=repo_name,
                               commit_message=commit_message)
        print(f"Model uploaded to {repo_name}")
def get_max_vocab_size(repo_list):
    """Return ``(max_vocab_size, repo_name)`` across the repos' config.json files.

    Args:
        repo_list: Iterable of repo names — either plain strings or tuples whose
            first element is the repo name (both call styles exist in this file).

    Returns:
        Tuple of the largest ``vocab_size`` seen and the repo that has it, or
        ``(0, None)`` when nothing could be fetched.
    """
    max_vocab_size = 0
    repo_with_max_vocab = None
    base_url = "https://huggingface.co/{}/raw/main/config.json"
    for repo in repo_list:
        # BUG FIX: main() passes plain strings, but the original unconditionally
        # unpacked 3-tuples, raising ValueError. Accept both shapes.
        repo_name = repo if isinstance(repo, str) else repo[0]
        url = base_url.format(repo_name)
        try:
            response = requests.get(url, timeout=30)  # avoid hanging forever
            response.raise_for_status()  # surface HTTP errors (404 etc.) as exceptions
            vocab_size = response.json().get('vocab_size', 0)
        except (requests.RequestException, ValueError) as e:
            # ValueError covers malformed JSON, which RequestException does not.
            print(f"Error fetching vocab size from {repo_name}: {e}")
            continue
        if vocab_size > max_vocab_size:
            max_vocab_size = vocab_size
            repo_with_max_vocab = repo_name
    return max_vocab_size, repo_with_max_vocab
def download_json_files(repo_name, file_paths, output_dir):
    """Download raw files from a Hugging Face repo's main branch into ``output_dir``.

    Args:
        repo_name: Hub repo identifier (e.g. ``"org/model"``).
        file_paths: Iterable of in-repo file paths to fetch.
        output_dir: Local directory to write into (created if missing).
    """
    base_url = f"https://huggingface.co/{repo_name}/raw/main/"
    # BUG FIX: open() below fails with FileNotFoundError if output_dir is missing.
    os.makedirs(output_dir, exist_ok=True)
    for file_path in file_paths:
        url = base_url + file_path
        response = requests.get(url, timeout=30)  # avoid hanging indefinitely
        if response.status_code == 200:
            with open(os.path.join(output_dir, os.path.basename(file_path)), 'wb') as fh:
                fh.write(response.content)
        else:
            # Include the status code so failures are diagnosable from the log.
            print(f"Failed to download {file_path} from {repo_name} "
                  f"(HTTP {response.status_code})")
def main():
    """CLI entry point: merge the given repos into the first one, then optionally upload."""
    parser = argparse.ArgumentParser(description="Merge and upload HuggingFace models")
    parser.add_argument('repos', nargs='+',
                        help='Repositories to merge (the first is the base model)')
    parser.add_argument('output_dir', help='Output directory')
    parser.add_argument('-staging', default='./staging', help='Staging folder')
    parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
    # BUG FIX: `lambda` is a Python keyword, so the original `args.lambda` was a
    # SyntaxError. `dest` keeps the CLI flag identical but gives a legal attribute.
    parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0,
                        help='Scaling factor')
    parser.add_argument('--dry', action='store_true', help='Dry run mode')
    parser.add_argument('--token', type=str, help='HuggingFace token')
    parser.add_argument('--repo', type=str, help='HuggingFace repo to upload to')
    parser.add_argument('--commit-message', type=str, default='Upload merged model',
                        help='Commit message for model upload')
    args = parser.parse_args()

    staging_path = os.path.abspath(args.staging)
    os.makedirs(staging_path, exist_ok=True)

    # BUG FIX: the base model is the first repo on the command line, not a Hub
    # repo literally named "base_model"; only the local folder keeps that name.
    base_model_name = args.repos[0]
    base_model_path = os.path.join(staging_path, "base_model")
    staging_model_path = os.path.join(staging_path, "staging_model")

    model_merger = ModelMerger(staging_path, args.repo, args.token, args.dry)
    model_merger.prepare_base_model(base_model_name, base_model_path)
    for repo_name in tqdm(args.repos[1:], desc="Merging Repos"):
        model_merger.merge_repo(repo_name, staging_model_path, args.p, args.lambda_val)
    model_merger.finalize_merge(args.output_dir)
    # Removed the dead get_max_vocab_size() call: its result was never used and
    # it cost one network round-trip per repo.
    if args.repo:
        model_merger.upload_model(args.output_dir, args.repo, args.commit_message)


if __name__ == "__main__":
    main()