msr2000 commited on
Commit
9672b38
1 Parent(s): 9ffbf9c

Fix fp8_cast_bf16.py: https://github.com/deepseek-ai/DeepSeek-V3/commit/8f1c9488b53068992f9525fab03b1868e6f7c8c1

Browse files
Files changed (1) hide show
  1. inference/fp8_cast_bf16.py +37 -11
inference/fp8_cast_bf16.py CHANGED
@@ -16,32 +16,58 @@ def main(fp8_path, bf16_path):
16
  with open(model_index_file, "r") as f:
17
  model_index = json.load(f)
18
  weight_map = model_index["weight_map"]
19
- fp8_weight_names = []
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
 
22
  for safetensor_file in tqdm(safetensor_files):
23
  file_name = os.path.basename(safetensor_file)
24
- state_dict = load_file(safetensor_file, device="cuda")
 
 
25
  new_state_dict = {}
26
- for weight_name, weight in state_dict.items():
27
  if weight_name.endswith("_scale_inv"):
28
  continue
29
- elif weight.element_size() == 1:
30
  scale_inv_name = f"{weight_name}_scale_inv"
31
- assert scale_inv_name in state_dict
32
- fp8_weight_names.append(weight_name)
33
- scale_inv = state_dict[scale_inv_name]
34
- new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
 
 
 
 
35
  else:
36
  new_state_dict[weight_name] = weight
 
37
  new_safetensor_file = os.path.join(bf16_path, file_name)
38
  save_file(new_state_dict, new_safetensor_file)
 
 
 
 
 
 
39
 
 
40
  new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
41
  for weight_name in fp8_weight_names:
42
  scale_inv_name = f"{weight_name}_scale_inv"
43
- assert scale_inv_name in weight_map
44
- weight_map.pop(scale_inv_name)
45
  with open(new_model_index_file, "w") as f:
46
  json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
47
 
@@ -52,4 +78,4 @@ if __name__ == "__main__":
52
  parser.add_argument("--output-bf16-hf-path", type=str, required=True)
53
  args = parser.parse_args()
54
  main(args.input_fp8_hf_path, args.output_bf16_hf_path)
55
-
 
16
  with open(model_index_file, "r") as f:
17
  model_index = json.load(f)
18
  weight_map = model_index["weight_map"]
 
19
 
20
+ # Cache for loaded safetensor files
21
+ loaded_files = {}
22
+ fp8_weight_names = []
23
+
24
+ # Helper function to get tensor from the correct file
25
+ def get_tensor(tensor_name):
26
+ file_name = weight_map[tensor_name]
27
+ if file_name not in loaded_files:
28
+ file_path = os.path.join(fp8_path, file_name)
29
+ loaded_files[file_name] = load_file(file_path, device="cuda")
30
+ return loaded_files[file_name][tensor_name]
31
+
32
  safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
33
+ safetensor_files.sort()
34
  for safetensor_file in tqdm(safetensor_files):
35
  file_name = os.path.basename(safetensor_file)
36
+ current_state_dict = load_file(safetensor_file, device="cuda")
37
+ loaded_files[file_name] = current_state_dict
38
+
39
  new_state_dict = {}
40
+ for weight_name, weight in current_state_dict.items():
41
  if weight_name.endswith("_scale_inv"):
42
  continue
43
+ elif weight.element_size() == 1: # FP8 weight
44
  scale_inv_name = f"{weight_name}_scale_inv"
45
+ try:
46
+ # Get scale_inv from the correct file
47
+ scale_inv = get_tensor(scale_inv_name)
48
+ fp8_weight_names.append(weight_name)
49
+ new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
50
+ except KeyError:
51
+ print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
52
+ new_state_dict[weight_name] = weight
53
  else:
54
  new_state_dict[weight_name] = weight
55
+
56
  new_safetensor_file = os.path.join(bf16_path, file_name)
57
  save_file(new_state_dict, new_safetensor_file)
58
+
59
+ # Memory management: keep only the 2 most recently used files
60
+ if len(loaded_files) > 2:
61
+ oldest_file = next(iter(loaded_files))
62
+ del loaded_files[oldest_file]
63
+ torch.cuda.empty_cache()
64
 
65
+ # Update model index
66
  new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
67
  for weight_name in fp8_weight_names:
68
  scale_inv_name = f"{weight_name}_scale_inv"
69
+ if scale_inv_name in weight_map:
70
+ weight_map.pop(scale_inv_name)
71
  with open(new_model_index_file, "w") as f:
72
  json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
73
 
 
78
  parser.add_argument("--output-bf16-hf-path", type=str, required=True)
79
  args = parser.parse_args()
80
  main(args.input_fp8_hf_path, args.output_bf16_hf_path)
81
+