mgoin committed
Commit 66b4234 · verified · 1 Parent(s): 2b587ed

Update README.md

Files changed (1)
  1. README.md +83 -2
README.md CHANGED
@@ -5,9 +5,90 @@ license_link: >-
  https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf
  ---
 
- # NOTICE; PLEASE READ. NO INFERENCE. (YET)
- 
- **This has no support for inference, yet.** All I've done is move the weights out of NVIDIA's NeMo architecture so people smarter than me can get a head start on making it work with other backends.
+ FP8 weight-only quantized checkpoint of https://huggingface.co/mgoin/Nemotron-4-340B-Instruct-vllm. For use with https://github.com/vllm-project/vllm/pull/6611
+ 
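+ As a rough, untested sketch of that usage (assuming the PR's branch is installed; the repo id and parallelism below are assumptions, not part of this commit), serving might look like:
+ ```python
+ from vllm import LLM, SamplingParams
+ 
+ # Hypothetical usage sketch: requires the branch from vllm-project/vllm#6611.
+ llm = LLM(
+     model="mgoin/Nemotron-4-340B-Instruct-FP8",  # assumed repo id for this checkpoint
+     tensor_parallel_size=8,  # illustrative; a 340B model needs many GPUs
+ )
+ outputs = llm.generate(["Write a haiku about GPUs."], SamplingParams(max_tokens=64))
+ print(outputs[0].outputs[0].text)
+ ```
+ 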
+ This script was used to create this model, in addition to adding the quantization config to config.json:
+ ```python
+ import argparse
+ import os
+ import json
+ import torch
+ import safetensors.torch
+ 
+ def per_tensor_quantize(tensor):
+     """Quantize a tensor to FP8 using a per-tensor static scaling factor."""
+     finfo = torch.finfo(torch.float8_e4m3fn)
+     if tensor.numel() == 0:
+         # Fall back to a fixed range for empty tensors, which have no amax.
+         min_val, max_val = torch.tensor(-16.0, dtype=tensor.dtype), torch.tensor(16.0, dtype=tensor.dtype)
+     else:
+         min_val, max_val = tensor.aminmax()
+     amax = torch.maximum(min_val.abs(), max_val.abs())
+     scale = finfo.max / amax.clamp(min=1e-12)
+     # Scale into the representable FP8 range, then cast down.
+     qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
+     # Store the inverse scale so dequantization is a simple multiply.
+     scale = scale.float().reciprocal()
+     return qweight, scale
+ 
+ def process_safetensors_file(file_path):
+     """Process a single safetensors file in-place, quantizing weights to FP8."""
+     print(f"Processing {file_path}")
+     tensors = safetensors.torch.load_file(file_path)
+ 
+     modified_tensors = {}
+     for name, tensor in tensors.items():
+         # Only the linear projection weights are quantized; everything else
+         # (embeddings, norms, etc.) keeps its original precision.
+         if name.endswith('_proj.weight'):
+             print("Quantizing", name)
+             qweight, scale = per_tensor_quantize(tensor)
+             modified_tensors[name] = qweight
+             modified_tensors[f"{name}_scale"] = scale
+         else:
+             modified_tensors[name] = tensor
+ 
+     safetensors.torch.save_file(modified_tensors, file_path)
+     print(f"Updated {file_path} with quantized tensors")
+ 
+ def update_index_file(index_file_path):
+     """Update the index file for the quantized model."""
+     print(f"Updating index file: {index_file_path}")
+     with open(index_file_path, 'r') as f:
+         index = json.load(f)
+ 
+     # Register each new scale tensor in the same shard as its weight.
+     new_weight_map = {}
+     for tensor_name, file_name in index['weight_map'].items():
+         new_weight_map[tensor_name] = file_name
+         if tensor_name.endswith('_proj.weight'):
+             new_weight_map[f"{tensor_name}_scale"] = file_name
+ 
+     index['weight_map'] = new_weight_map
+ 
+     # Recalculate total_size from the shard files on disk.
+     total_size = sum(os.path.getsize(os.path.join(os.path.dirname(index_file_path), file))
+                      for file in set(index['weight_map'].values()))
+     index['metadata']['total_size'] = total_size
+ 
+     with open(index_file_path, 'w') as f:
+         json.dump(index, f, indent=2)
+     print(f"Updated index file {index_file_path}")
+ 
+ def process_directory(directory):
+     """Process all safetensors files in the given directory."""
+     index_file_path = None
+     for filename in os.listdir(directory):
+         file_path = os.path.join(directory, filename)
+         if filename.endswith('.safetensors'):
+             process_safetensors_file(file_path)
+         elif filename == 'model.safetensors.index.json':
+             index_file_path = file_path
+ 
+     # Only rewrite the index if a sharded-index file was actually found.
+     if index_file_path is not None:
+         update_index_file(index_file_path)
+ 
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='Convert safetensors model to FP8 in-place.')
+     parser.add_argument('directory', type=str, help='The directory containing the safetensors files and index file.')
+ 
+     args = parser.parse_args()
+     process_directory(args.directory)
+ ```
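+ 
+ The quantization config mentioned above was added to config.json by hand. A minimal sketch of that step (the key names are an assumption based on vLLM's FP8 support, not copied from this commit):
+ ```python
+ import json
+ 
+ # Hypothetical: tag the checkpoint so vLLM selects its FP8 weight-only path.
+ # Verify the exact schema against the PR before relying on it.
+ with open("config.json") as f:
+     config = json.load(f)
+ config["quantization_config"] = {
+     "quant_method": "fp8",           # assumed key/value
+     "activation_scheme": "dynamic",  # assumed: activations scaled at runtime
+ }
+ with open("config.json", "w") as f:
+     json.dump(config, f, indent=2)
+ ```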
 
  ## Nemotron-4-340B-Instruct