DARE-MERGE-SAFETENSORS / hf_merge.py
mrcuddle's picture
Update hf_merge.py
33e9d70 verified
raw
history blame
5.94 kB
import os
import shutil
import argparse
import requests
from tqdm import tqdm
from huggingface_hub import HfApi, Repository, hf_hub_download, upload_folder
from merge import merge_folder, map_tensors_to_files, copy_nontensor_files, save_tensor_map
class RepositoryManager:
    """Download and delete local copies of Hugging Face Hub repositories.

    Every filesystem/network action honours ``dry_run``: when it is true the
    action is only printed, never performed.
    """

    def __init__(self, repo_id=None, token=None, dry_run=False):
        """Store settings and create an ``HfApi`` client.

        Args:
            repo_id: Target repository id; stored but not used by this class.
            token: Optional Hugging Face access token, forwarded to Hub calls.
            dry_run: If true, print planned actions instead of performing them.
        """
        self.repo_id = repo_id
        self.token = token
        self.dry_run = dry_run
        self.api = HfApi(token=token) if token else HfApi()

    def download_repo(self, repo_name, path):
        """Download every file of Hub repo ``repo_name`` into directory ``path``.

        Args:
            repo_name: Hub repo id such as ``"org/model"``.
            path: Local directory to download into; created if missing.
        """
        if self.dry_run:
            print(f"[DRY RUN] Downloading {repo_name} to {path}")
            return
        os.makedirs(path, exist_ok=True)
        repo_files = self.api.list_repo_files(repo_name)
        for file_path in tqdm(repo_files, desc=f"Downloading {repo_name}"):
            # Only local_dir is given: also pointing cache_dir at the same
            # folder can leave Hub cache blobs next to the model files that
            # are later scanned for tensors.
            hf_hub_download(
                repo_id=repo_name,
                filename=file_path,
                local_dir=path,
                token=self.token,  # required for private/gated repos
            )

    def delete_repo(self, path):
        """Recursively delete ``path``; a missing path is silently ignored."""
        if self.dry_run:
            print(f"[DRY RUN] Deleting {path}")
            return
        shutil.rmtree(path, ignore_errors=True)
        print(f"Deleted {path}")
class ModelMerger:
    """Merge several Hugging Face Hub models into a base model and publish it.

    Workflow: ``prepare_base_model`` -> ``merge_repo`` (once per extra repo)
    -> ``finalize_merge`` -> optionally ``upload_model``. The merged tensors
    are carried between steps in ``self.tensor_map``.
    """

    def __init__(self, staging_path, repo_id=None, token=None, dry_run=False):
        """Store settings and create an ``HfApi`` client.

        Args:
            staging_path: Working directory (str or path-like); the base model
                is read back from ``<staging_path>/base_model`` when finalizing.
            repo_id: Hub repo id to upload to; forwarded to RepositoryManager.
            token: Optional Hugging Face access token.
            dry_run: If true, print planned actions instead of performing them.
        """
        self.staging_path = staging_path
        self.repo_id = repo_id
        self.token = token
        self.dry_run = dry_run
        self.tensor_map = None
        self.api = HfApi(token=token) if token else HfApi()

    def prepare_base_model(self, base_model_name, base_model_path):
        """Download the base model and build the initial tensor map from it."""
        repo_manager = RepositoryManager(self.repo_id, self.token, self.dry_run)
        repo_manager.download_repo(base_model_name, base_model_path)
        if self.dry_run:
            # Nothing was downloaded, so there are no tensors to map.
            return
        self.tensor_map = map_tensors_to_files(base_model_path)

    def merge_repo(self, repo_name, repo_path, p, lambda_val):
        """Download ``repo_name`` into ``repo_path`` and merge it into the map.

        Any previous contents of ``repo_path`` are deleted first. ``p`` and
        ``lambda_val`` are passed through to ``merge_folder``.
        """
        repo_manager = RepositoryManager(self.repo_id, self.token, self.dry_run)
        repo_manager.delete_repo(repo_path)
        repo_manager.download_repo(repo_name, repo_path)
        if self.dry_run:
            print(f"[DRY RUN] Merging {repo_name} with p={p} and lambda={lambda_val}")
            return
        try:
            self.tensor_map = merge_folder(self.tensor_map, repo_path, p, lambda_val)
            print(f"Merged {repo_name}")
        except Exception as e:
            # Best effort: a failing repo is reported and skipped so the
            # remaining repos can still be merged.
            print(f"Error merging {repo_name}: {e}")

    def finalize_merge(self, output_dir):
        """Write the merged model (base non-tensor files + tensors) to ``output_dir``."""
        if self.dry_run:
            print(f"[DRY RUN] Writing merged model to {output_dir}")
            return
        os.makedirs(output_dir, exist_ok=True)
        # os.path.join works whether staging_path is a str or a Path; the
        # ``/`` operator raised TypeError for the str passed by main().
        base_model_dir = os.path.join(self.staging_path, 'base_model')
        copy_nontensor_files(base_model_dir, output_dir)
        save_tensor_map(self.tensor_map, output_dir)

    def upload_model(self, output_dir, repo_name, commit_message):
        """Upload ``output_dir`` to Hub repo ``repo_name``, creating it if needed."""
        if self.dry_run:
            print(f"[DRY RUN] Uploading model to {repo_name}")
            return
        # HTTP API: create the repo if it does not exist, then commit the
        # whole folder to its default branch in a single commit.
        self.api.create_repo(repo_id=repo_name, exist_ok=True)
        self.api.upload_folder(
            repo_id=repo_name,
            folder_path=output_dir,
            commit_message=commit_message,
        )
        print(f"Model uploaded to {repo_name}")
def get_max_vocab_size(repo_list, timeout=30):
    """Return the largest ``vocab_size`` declared in the repos' ``config.json``.

    Args:
        repo_list: Iterable whose items are either repo-id strings
            (``"org/name"``) or sequences whose first element is the repo id,
            e.g. ``(repo_name, p, lambda)`` triples.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        ``(max_vocab_size, repo_with_max_vocab)``; ``(0, None)`` when no
        config reports a positive integer ``vocab_size``. Fetch/parse errors
        are printed and that repo is skipped.
    """
    max_vocab_size = 0
    repo_with_max_vocab = None
    base_url = "https://huggingface.co/{}/raw/main/config.json"
    for entry in repo_list:
        # Unpacking a bare string as a triple raises ValueError, so accept
        # both plain repo ids and (repo_id, ...) sequences.
        repo_name = entry if isinstance(entry, str) else entry[0]
        url = base_url.format(repo_name)
        try:
            response = requests.get(url, timeout=timeout)
            # A 404 page is not JSON; fail explicitly instead of on .json().
            response.raise_for_status()
            config = response.json()
        except (requests.RequestException, ValueError) as e:
            print(f"Error fetching vocab size from {repo_name}: {e}")
            continue
        if not isinstance(config, dict):
            continue
        vocab_size = config.get('vocab_size', 0)
        # vocab_size may be null or missing in config.json; compare only ints.
        if isinstance(vocab_size, int) and vocab_size > max_vocab_size:
            max_vocab_size = vocab_size
            repo_with_max_vocab = repo_name
    return max_vocab_size, repo_with_max_vocab
def download_json_files(repo_name, file_paths, output_dir, timeout=30):
    """Fetch raw files from the ``main`` branch of a Hub repo into ``output_dir``.

    Each file is saved under its basename, so subdirectories in ``file_paths``
    are flattened. ``output_dir`` is created if missing. A file that cannot
    be fetched is reported and skipped; the remaining files are still tried.

    Args:
        repo_name: Hub repo id such as ``"org/model"``.
        file_paths: Paths of files inside the repo.
        output_dir: Local destination directory.
        timeout: Seconds to wait for each HTTP request.
    """
    os.makedirs(output_dir, exist_ok=True)
    base_url = f"https://huggingface.co/{repo_name}/raw/main/"
    for file_path in file_paths:
        url = base_url + file_path
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException as e:
            print(f"Failed to download {file_path} from {repo_name}: {e}")
            continue
        if response.status_code != 200:
            print(f"Failed to download {file_path} from {repo_name}")
            continue
        destination = os.path.join(output_dir, os.path.basename(file_path))
        with open(destination, 'wb') as file:
            file.write(response.content)
def main():
    """Command-line entry point: merge Hub repos and optionally upload the result.

    Positional arguments are ``repo [repo ...] output_dir``. The first repo is
    the base model; every later repo is merged into it in order.
    """
    parser = argparse.ArgumentParser(description="Merge and upload HuggingFace models")
    parser.add_argument('repos', nargs='+', help='Repositories to merge')
    parser.add_argument('output_dir', help='Output directory')
    parser.add_argument('-staging', default='./staging', help='Staging folder')
    parser.add_argument('-p', type=float, default=0.5, help='Dropout probability')
    # ``lambda`` is a Python keyword, so ``args.lambda`` is a SyntaxError;
    # keep the ``-lambda`` flag but store it under a usable attribute name.
    parser.add_argument('-lambda', dest='lambda_val', type=float, default=1.0, help='Scaling factor')
    parser.add_argument('--dry', action='store_true', help='Dry run mode')
    parser.add_argument('--token', type=str, help='HuggingFace token')
    parser.add_argument('--repo', type=str, help='HuggingFace repo to upload to')
    parser.add_argument('--commit-message', type=str, default='Upload merged model', help='Commit message for model upload')
    args = parser.parse_args()

    staging_path = os.path.abspath(args.staging)
    os.makedirs(staging_path, exist_ok=True)
    # The first repo is the base model; the loop below merges repos[1:] into it.
    base_model_name = args.repos[0]
    base_model_path = os.path.join(staging_path, "base_model")
    staging_model_path = os.path.join(staging_path, "staging_model")

    model_merger = ModelMerger(staging_path, args.repo, args.token, args.dry)
    model_merger.prepare_base_model(base_model_name, base_model_path)
    for repo_name in tqdm(args.repos[1:], desc="Merging Repos"):
        model_merger.merge_repo(repo_name, staging_model_path, args.p, args.lambda_val)
    model_merger.finalize_merge(args.output_dir)

    # get_max_vocab_size expects (repo_name, p, lambda) triples.
    max_vocab_size, vocab_repo = get_max_vocab_size(
        [(repo, args.p, args.lambda_val) for repo in args.repos]
    )
    if vocab_repo is not None:
        print(f"Largest vocab_size among inputs: {max_vocab_size} ({vocab_repo})")

    if args.repo:
        model_merger.upload_model(args.output_dir, args.repo, args.commit_message)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()