model: update merge script config
Browse files
merge.py
CHANGED
@@ -22,8 +22,8 @@ def create_merge_plan(tensor_locations, layer_config):
|
|
22 |
# Special handling for specific weights
|
23 |
special_weights = {
|
24 |
"model.embed_tokens.weight": 1,
|
25 |
-
"lm_head.weight":
|
26 |
-
"model.norm.weight":
|
27 |
}
|
28 |
|
29 |
for slice_config in layer_config:
|
@@ -65,13 +65,18 @@ def create_merge_plan(tensor_locations, layer_config):
|
|
65 |
|
66 |
return merge_plan
|
67 |
|
68 |
-
def merge_layers(input_dir, output_dir, merge_plan):
|
69 |
output_tensors = {}
|
70 |
-
current_new_file_index = 1
|
71 |
max_file_index = max(item['new_file_index'] for item in merge_plan)
|
72 |
|
73 |
with tqdm(total=len(merge_plan), desc="Merging layers") as pbar:
|
74 |
-
for file_index in range(
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
for item in merge_plan:
|
76 |
if item['new_file_index'] == file_index:
|
77 |
input_file = os.path.join(input_dir, f"model-{item['original_file_index']:05d}-of-00051.safetensors")
|
@@ -81,7 +86,6 @@ def merge_layers(input_dir, output_dir, merge_plan):
|
|
81 |
pbar.update(1)
|
82 |
|
83 |
if output_tensors:
|
84 |
-
output_file = os.path.join(output_dir, f"model-{file_index:05d}-of-{max_file_index:05d}.safetensors")
|
85 |
save_file(output_tensors, output_file)
|
86 |
output_tensors = {}
|
87 |
|
@@ -92,6 +96,7 @@ def main():
|
|
92 |
parser.add_argument("input_dir", help="Directory containing input safetensors files")
|
93 |
parser.add_argument("output_dir", help="Directory for output safetensors files")
|
94 |
parser.add_argument("--dry-run", action="store_true", help="Perform a dry run and output merge plan")
|
|
|
95 |
args = parser.parse_args()
|
96 |
|
97 |
layer_config = [
|
@@ -116,8 +121,8 @@ def main():
|
|
116 |
print("Merge plan saved to merge_plan.json")
|
117 |
else:
|
118 |
os.makedirs(args.output_dir, exist_ok=True)
|
119 |
-
merge_layers(args.input_dir, args.output_dir, merge_plan)
|
120 |
print(f"Merged model saved to {args.output_dir}")
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
-
main()
|
|
|
22 |
# Special handling for specific weights
|
23 |
special_weights = {
|
24 |
"model.embed_tokens.weight": 1,
|
25 |
+
"lm_head.weight": 156,
|
26 |
+
"model.norm.weight": 156
|
27 |
}
|
28 |
|
29 |
for slice_config in layer_config:
|
|
|
65 |
|
66 |
return merge_plan
|
67 |
|
68 |
+
def merge_layers(input_dir, output_dir, merge_plan, start_file_index=1):
|
69 |
output_tensors = {}
|
|
|
70 |
max_file_index = max(item['new_file_index'] for item in merge_plan)
|
71 |
|
72 |
with tqdm(total=len(merge_plan), desc="Merging layers") as pbar:
|
73 |
+
for file_index in range(start_file_index, max_file_index + 1):
|
74 |
+
output_file = os.path.join(output_dir, f"model-{file_index:05d}-of-{max_file_index:05d}.safetensors")
|
75 |
+
|
76 |
+
if os.path.exists(output_file):
|
77 |
+
pbar.update(sum(1 for item in merge_plan if item['new_file_index'] == file_index))
|
78 |
+
continue
|
79 |
+
|
80 |
for item in merge_plan:
|
81 |
if item['new_file_index'] == file_index:
|
82 |
input_file = os.path.join(input_dir, f"model-{item['original_file_index']:05d}-of-00051.safetensors")
|
|
|
86 |
pbar.update(1)
|
87 |
|
88 |
if output_tensors:
|
|
|
89 |
save_file(output_tensors, output_file)
|
90 |
output_tensors = {}
|
91 |
|
|
|
96 |
parser.add_argument("input_dir", help="Directory containing input safetensors files")
|
97 |
parser.add_argument("output_dir", help="Directory for output safetensors files")
|
98 |
parser.add_argument("--dry-run", action="store_true", help="Perform a dry run and output merge plan")
|
99 |
+
parser.add_argument("--continue-from", type=int, default=1, help="Continue merging from this file index")
|
100 |
args = parser.parse_args()
|
101 |
|
102 |
layer_config = [
|
|
|
121 |
print("Merge plan saved to merge_plan.json")
|
122 |
else:
|
123 |
os.makedirs(args.output_dir, exist_ok=True)
|
124 |
+
merge_layers(args.input_dir, args.output_dir, merge_plan, start_file_index=args.continue_from)
|
125 |
print(f"Merged model saved to {args.output_dir}")
|
126 |
|
127 |
if __name__ == "__main__":
|
128 |
+
main()
|