leafspark commited on
Commit
c6d7757
·
verified ·
1 Parent(s): 0399d2b

model: update merge script config

Browse files
Files changed (1) hide show
  1. merge.py +13 -8
merge.py CHANGED
@@ -22,8 +22,8 @@ def create_merge_plan(tensor_locations, layer_config):
22
  # Special handling for specific weights
23
  special_weights = {
24
  "model.embed_tokens.weight": 1,
25
- "lm_head.weight": 48,
26
- "model.norm.weight": 48
27
  }
28
 
29
  for slice_config in layer_config:
@@ -65,13 +65,18 @@ def create_merge_plan(tensor_locations, layer_config):
65
 
66
  return merge_plan
67
 
68
- def merge_layers(input_dir, output_dir, merge_plan):
69
  output_tensors = {}
70
- current_new_file_index = 1
71
  max_file_index = max(item['new_file_index'] for item in merge_plan)
72
 
73
  with tqdm(total=len(merge_plan), desc="Merging layers") as pbar:
74
- for file_index in range(1, max_file_index + 1):
 
 
 
 
 
 
75
  for item in merge_plan:
76
  if item['new_file_index'] == file_index:
77
  input_file = os.path.join(input_dir, f"model-{item['original_file_index']:05d}-of-00051.safetensors")
@@ -81,7 +86,6 @@ def merge_layers(input_dir, output_dir, merge_plan):
81
  pbar.update(1)
82
 
83
  if output_tensors:
84
- output_file = os.path.join(output_dir, f"model-{file_index:05d}-of-{max_file_index:05d}.safetensors")
85
  save_file(output_tensors, output_file)
86
  output_tensors = {}
87
 
@@ -92,6 +96,7 @@ def main():
92
  parser.add_argument("input_dir", help="Directory containing input safetensors files")
93
  parser.add_argument("output_dir", help="Directory for output safetensors files")
94
  parser.add_argument("--dry-run", action="store_true", help="Perform a dry run and output merge plan")
 
95
  args = parser.parse_args()
96
 
97
  layer_config = [
@@ -116,8 +121,8 @@ def main():
116
  print("Merge plan saved to merge_plan.json")
117
  else:
118
  os.makedirs(args.output_dir, exist_ok=True)
119
- merge_layers(args.input_dir, args.output_dir, merge_plan)
120
  print(f"Merged model saved to {args.output_dir}")
121
 
122
  if __name__ == "__main__":
123
- main()
 
22
  # Special handling for specific weights
23
  special_weights = {
24
  "model.embed_tokens.weight": 1,
25
+ "lm_head.weight": 156,
26
+ "model.norm.weight": 156
27
  }
28
 
29
  for slice_config in layer_config:
 
65
 
66
  return merge_plan
67
 
68
+ def merge_layers(input_dir, output_dir, merge_plan, start_file_index=1):
69
  output_tensors = {}
 
70
  max_file_index = max(item['new_file_index'] for item in merge_plan)
71
 
72
  with tqdm(total=len(merge_plan), desc="Merging layers") as pbar:
73
+ for file_index in range(start_file_index, max_file_index + 1):
74
+ output_file = os.path.join(output_dir, f"model-{file_index:05d}-of-{max_file_index:05d}.safetensors")
75
+
76
+ if os.path.exists(output_file):
77
+ pbar.update(sum(1 for item in merge_plan if item['new_file_index'] == file_index))
78
+ continue
79
+
80
  for item in merge_plan:
81
  if item['new_file_index'] == file_index:
82
  input_file = os.path.join(input_dir, f"model-{item['original_file_index']:05d}-of-00051.safetensors")
 
86
  pbar.update(1)
87
 
88
  if output_tensors:
 
89
  save_file(output_tensors, output_file)
90
  output_tensors = {}
91
 
 
96
  parser.add_argument("input_dir", help="Directory containing input safetensors files")
97
  parser.add_argument("output_dir", help="Directory for output safetensors files")
98
  parser.add_argument("--dry-run", action="store_true", help="Perform a dry run and output merge plan")
99
+ parser.add_argument("--continue-from", type=int, default=1, help="Continue merging from this file index")
100
  args = parser.parse_args()
101
 
102
  layer_config = [
 
121
  print("Merge plan saved to merge_plan.json")
122
  else:
123
  os.makedirs(args.output_dir, exist_ok=True)
124
+ merge_layers(args.input_dir, args.output_dir, merge_plan, start_file_index=args.continue_from)
125
  print(f"Merged model saved to {args.output_dir}")
126
 
127
  if __name__ == "__main__":
128
+ main()