OllieStanley commited on
Commit
32c97cb
·
1 Parent(s): 2009fe0
Files changed (2) hide show
  1. README.md +86 -0
  2. xor_codec.py +83 -0
README.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ ---
4
+
5
+ # OpenAssistant LLaMa-Based Models
6
+
7
+ Due to the license attached to LLaMa models by Meta AI it is not possible to directly distribute LLaMa-based models. Instead we provide XOR weights for the OA models.
8
+
9
+ Thanks to Mick for writing the `xor_codec.py` script which enables this process
10
+
11
+ ## The Process
12
+
13
+ Note: This process applies to `oasst-sft-6-llama-30b` model. The same process can be applied to other models in future, but the checksums will be different..
14
+
15
+ To use OpenAssistant LLaMa-Based Models, you need to have a copy of the original LLaMa model weights and add them to a `llama` subdirectory here.
16
+
17
+ Ensure your LLaMa 30B checkpoint matches the correct md5sums:
18
+
19
+ ```
20
+ f856e9d99c30855d6ead4d00cc3a5573 consolidated.00.pth
21
+ d9dbfbea61309dc1e087f5081e98331a consolidated.01.pth
22
+ 2b2bed47912ceb828c0a37aac4b99073 consolidated.02.pth
23
+ ea0405cdb5bc638fee12de614f729ebc consolidated.03.pth
24
+ 4babdbd05b8923226a9e9622492054b6 params.json
25
+ ```
26
+
27
+ These can be converted to HuggingFace Transformers-compatible weights using the script available [here](https://github.com/huggingface/transformers/blob/28f26c107b4a1c5c7e32ed4d9575622da0627a40/src/transformers/models/llama/convert_llama_weights_to_hf.py).
28
+
29
+ **Important**: It was tested with git version transformers 4.28.0.dev0 (git hash: **28f26c107b4a1c5c7e32ed4d9575622da0627a40**). Make sure the package tokenizers 0.13.3 is installed. Use of different versions may result in broken outputs.
30
+
31
+ ```
32
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python convert_llama_weights_to_hf.py --input_dir ~/llama/ --output_dir ~/llama30b_hf/ --model_size 30B
33
+ ```
34
+
35
+ Run `find -type f -exec md5sum "{}" + > checklist.chk` in the conversion target directory. This should produce a `checklist.chk` with exactly the following content if your files are correct:
36
+
37
+ ```
38
+ d0e13331c103453e9e087d59dcf05432 ./pytorch_model-00001-of-00007.bin
39
+ 29aae4d31a0a4fe6906353001341d493 ./pytorch_model-00002-of-00007.bin
40
+ b40838eb4e68e087b15b3d653ca1f5d7 ./pytorch_model-00003-of-00007.bin
41
+ f845ecc481cb92b8a0586c2ce288b828 ./pytorch_model-00004-of-00007.bin
42
+ f3b13d089840e6caf22cd6dd05b77ef0 ./pytorch_model-00005-of-00007.bin
43
+ 12e0d2d7a9c00c4237b1b0143c48a05e ./pytorch_model-00007-of-00007.bin
44
+ 1348f7c8bb3ee4408b69305a10bdfafb ./pytorch_model-00006-of-00007.bin
45
+ aee09e21813368c49baaece120125ae3 ./generation_config.json
46
+ eeec4125e9c7560836b4873b6f8e3025 ./tokenizer.model
47
+ 598538f18fed1877b41f77de034c0c8a ./config.json
48
+ fdb311c39b8659a5d5c1991339bafc09 ./tokenizer.json
49
+ b77e99aa2ddc3df500c2b2dc4455a6af ./pytorch_model.bin.index.json
50
+ edd1a5897748864768b1fab645b31491 ./tokenizer_config.json
51
+ 6b2e0a735969660e720c27061ef3f3d3 ./special_tokens_map.json
52
+ ```
53
+
54
+ Once you have LLaMa weights in the correct format, you can apply the XOR decoding:
55
+
56
+ ```
57
+ python xor_codec.py oasst-sft-6-llama-30b/ oasst-sft-6-llama-30b-xor/ llama30b_hf/
58
+ ```
59
+
60
+ You should expect to see one warning message during execution:
61
+
62
+ `Exception when processing 'added_tokens.json'`
63
+
64
+ This is normal. If similar messages appear for other files, something has gone wrong.
65
+
66
+ Now run `find -type f -exec md5sum "{}" + > checklist.chk` in the output directory (here `oasst-sft-6-llama-30b`). You should get a file with exactly these contents:
67
+
68
+ ```
69
+ 970e99665d66ba3fad6fdf9b4910acc5 ./pytorch_model-00007-of-00007.bin
70
+ 659fcb7598dcd22e7d008189ecb2bb42 ./pytorch_model-00003-of-00007.bin
71
+ ff6e4cf43ddf02fb5d3960f850af1220 ./pytorch_model-00001-of-00007.bin
72
+ 27b0dc092f99aa2efaf467b2d8026c3f ./added_tokens.json
73
+ aee09e21813368c49baaece120125ae3 ./generation_config.json
74
+ 740c324ae65b1ec25976643cda79e479 ./pytorch_model-00005-of-00007.bin
75
+ f7aefb4c63be2ac512fd905b45295235 ./pytorch_model-00004-of-00007.bin
76
+ eeec4125e9c7560836b4873b6f8e3025 ./tokenizer.model
77
+ 369df2f0e38bda0d9629a12a77c10dfc ./pytorch_model-00006-of-00007.bin
78
+ 27b9c7c8c62db80e92de14724f4950f3 ./config.json
79
+ deb33dd4ffc3d2baddcce275a00b7c1b ./tokenizer.json
80
+ 76d47e4f51a8df1d703c6f594981fcab ./pytorch_model.bin.index.json
81
+ ed59bfee4e87b9193fea5897d610ab24 ./tokenizer_config.json
82
+ 130f5e690becc2223f59384887c2a505 ./special_tokens_map.json
83
+ ae48c4c68e4e171d502dd0896aa19a84 ./pytorch_model-00002-of-00007.bin
84
+ ```
85
+
86
+ If so you have successfully decoded the weights and should be able to use the model with HuggingFace Transformers.
xor_codec.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import gzip
5
+ import numpy
6
+ from pathlib import Path
7
+
8
+ def xor_uncompressed(dst, src_payload, src_base, block_size=4096):
9
+ fp_payload = open(src_payload, 'rb')
10
+ fp_base = open(src_base, 'rb')
11
+ with open(dst, 'wb') as fp:
12
+ while True:
13
+ buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
14
+ buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
15
+ padding = len(buf1) - len(buf2)
16
+ if padding > 0: buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
17
+ if padding < 0: buf2 = buf2[:len(buf1)]
18
+ buf = numpy.bitwise_xor(buf1, buf2)
19
+ fp.write(buf)
20
+ if len(buf1) < block_size: break
21
+ fp_payload.close()
22
+ fp_base.close()
23
+
24
+ def xor_encode(dst, src_payload, src_base, block_size=4096):
25
+ fp_payload = open(src_payload, 'rb')
26
+ fp_base = open(src_base, 'rb')
27
+ with gzip.open(dst, 'wb') as fp:
28
+ while True:
29
+ buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
30
+ buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
31
+ padding = len(buf1) - len(buf2)
32
+ if padding > 0: buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
33
+ if padding < 0: buf2 = buf2[:len(buf1)]
34
+ buf = numpy.bitwise_xor(buf1, buf2)
35
+ fp.write(buf)
36
+ if len(buf1) < block_size: break
37
+ fp_payload.close()
38
+ fp_base.close()
39
+
40
+ def xor_decode(dst, src_payload, src_base, block_size=4096):
41
+ fp_payload = gzip.open(src_payload, 'rb')
42
+ fp_base = open(src_base, 'rb')
43
+ with open(dst, 'wb') as fp:
44
+ while True:
45
+ buf1 = numpy.array(bytearray(fp_payload.read(block_size)), dtype=numpy.uint8)
46
+ buf2 = numpy.array(bytearray(fp_base.read(block_size)), dtype=numpy.uint8)
47
+ padding = len(buf1) - len(buf2)
48
+ if padding > 0: buf2 = numpy.pad(buf2, (0, padding), 'constant', constant_values=(0,))
49
+ if padding < 0: buf2 = buf2[:len(buf1)]
50
+ buf = numpy.bitwise_xor(buf1, buf2)
51
+ fp.write(buf)
52
+ if len(buf1) < block_size: break
53
+ fp_payload.close()
54
+ fp_base.close()
55
+
56
+ def xor_dir(dst, src_payload, src_base, decode=True, compress=True):
57
+ if compress:
58
+ xor = xor_decode if decode else xor_encode
59
+ else:
60
+ xor = xor_uncompressed
61
+ Path(dst).mkdir(parents=True, exist_ok=True)
62
+ shutil.copy(Path(src_payload) / "added_tokens.json", Path(dst) / "added_tokens.json")
63
+ for path in os.listdir(src_payload):
64
+ print("[*] Processing '%s'" % path)
65
+ try:
66
+ xor("%s/%s" % (dst, path), "%s/%s" % (src_payload, path), "%s/%s" % (src_base, path))
67
+ except Exception as e:
68
+ print("Exception when processing '%s'" % path)
69
+
70
+ if __name__ == "__main__":
71
+ if len(sys.argv) < 4:
72
+ print("Usage: xor.py <DESTINATION> <PAYLOAD SOURCE> <LLAMA SOURCE> [--encode] [--compress]")
73
+ exit()
74
+ dst = sys.argv[1]
75
+ src_payload = sys.argv[2]
76
+ src_base = sys.argv[3]
77
+ decode = True
78
+ compress = False
79
+ if len(sys.argv) > 4:
80
+ for arg in sys.argv[4:]:
81
+ if arg == "--encode": decode = False
82
+ if arg == "--compress": compress = True
83
+ xor_dir(dst, src_payload, src_base, decode=decode, compress=compress)