import os
import re
import sys

# Finds the line(s) in torch/nn/parameter.py that build the Parameter
# subclass. The leading whitespace is captured so the guard is inserted at the
# same nesting level as the matched line. Depending on the torch version this
# `return` may sit directly in `__new__` or inside an extra `if` block, and a
# fixed-indent replacement would corrupt the file (IndentationError) in the
# nested case.
_MAKE_SUBCLASS_RE = re.compile(
    r'^([ \t]*)return torch\.Tensor\._make_subclass\(cls, data, requires_grad\)',
    re.MULTILINE,
)

# Text whose presence means the file has already been patched.
_PATCH_MARKER = 'if data.dtype == torch.int8'


def _find_parameter_file(search_paths):
    """Return the first existing ``<path>/torch/nn/parameter.py``, or None."""
    for base in search_paths:
        candidate = os.path.join(base, 'torch', 'nn', 'parameter.py')
        if os.path.isfile(candidate):
            return candidate
    return None


def _add_int8_guard(match):
    """Prefix a matched ``return`` line with an int8 ``requires_grad`` guard."""
    indent = match.group(1)
    return (
        f'{indent}if data.dtype == torch.int8:\n'
        f'{indent}    requires_grad = False\n'
        f'{match.group(0)}'
    )


def fix_pytorch_int8(search_paths=None):
    """Patch the installed ``torch/nn/parameter.py`` so int8 data never requires grad.

    Before each ``return torch.Tensor._make_subclass(cls, data, requires_grad)``
    line, this inserts a guard that forces ``requires_grad = False`` when
    ``data.dtype == torch.int8``. PyTorch only lets floating-point and complex
    tensors require gradients. Without the guard, wrapping int8 data in
    ``nn.Parameter`` with its default ``requires_grad=True`` fails.

    The function edits a file inside the installed package, so it needs write
    access to that file. It only affects processes that import torch after the
    patch; a module that is already imported is not reloaded. Running it again
    on a patched file is a no-op.

    Args:
        search_paths: Directories to search, in order, for
            ``torch/nn/parameter.py``. The first hit is patched. Defaults to
            the existing directories on ``sys.path``.

    Raises:
        FileNotFoundError: If no search path contains ``torch/nn/parameter.py``.
        OSError: If the file cannot be read or written.
    """
    if search_paths is None:
        search_paths = [p for p in sys.path if p and os.path.isdir(p)]

    # Look for the actual file rather than any folder whose name merely
    # contains "torch" (torchvision, torch-*.dist-info, ...).
    fix_path = _find_parameter_file(search_paths)
    if fix_path is None:
        raise FileNotFoundError(
            'torch/nn/parameter.py not found in any search path'
        )

    with open(fix_path, 'r', encoding='utf-8') as f:
        text = f.read()
    if _PATCH_MARKER in text:
        return  # already patched

    patched, count = _MAKE_SUBCLASS_RE.subn(_add_int8_guard, text)
    if count == 0:
        # Unknown torch layout: leave the file untouched instead of claiming success.
        print(
            'torch/nn/parameter.py: patch target not found; left unchanged',
            file=sys.stderr,
        )
        return

    with open(fix_path, 'w', encoding='utf-8') as f:
        f.write(patched)
    print('Fixed torch/nn/parameter.py')