KumaGLM-Lite / fix_int8.py
import os
import sys


def fix_pytorch_int8():
    """Patch torch/nn/parameter.py so int8 tensors can be wrapped as Parameters.

    torch.Tensor._make_subclass cannot create a Parameter with
    requires_grad=True from an int8 (quantized) tensor, so the patch forces
    requires_grad to False for that dtype.
    """
    # Locate the site-packages directory that actually contains the torch package.
    valid_path = [p for p in sys.path if p and os.path.isdir(p)]
    packages_path = None
    for path in valid_path:
        if os.path.isdir(os.path.join(path, 'torch')):
            packages_path = path
            break
    if packages_path is None:
        print('torch package not found, nothing to patch')
        return

    fix_path = f'{packages_path}/torch/nn/parameter.py'
    with open(fix_path, 'r') as f:
        text = f.read()

    # Patch only once: skip if the int8 guard is already present.
    if 'if data.dtype == torch.int8' not in text:
        text = text.replace(
            '        return torch.Tensor._make_subclass(cls, data, requires_grad)',
            '        if data.dtype == torch.int8:\n'
            '            requires_grad = False\n'
            '        return torch.Tensor._make_subclass(cls, data, requires_grad)'
        )
    with open(fix_path, 'w') as f:
        f.write(text)
    print('Fixed torch/nn/parameter.py')
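A minimal usage sketch, assuming this file is importable as fix_int8 and is run before torch and the model are imported; the module name, model id, and loading call below are illustrative assumptions, not part of this file:

# Usage sketch (assumed: module name fix_int8, transformers-based loading).
from fix_int8 import fix_pytorch_int8

fix_pytorch_int8()  # patch site-packages/torch/nn/parameter.py first

# Import the model only after patching, so the modified Parameter code is what gets loaded.
from transformers import AutoModel  # illustrative; the actual loading code may differ
model = AutoModel.from_pretrained('KumaTea/KumaGLM-Lite', trust_remote_code=True)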