Spaces:
Runtime error
Runtime error
import os | |
import sys | |
def fix_pytorch_int8():
    """Patch the installed ``torch/nn/parameter.py`` to tolerate int8 data.

    The first ``sys.path`` entry that actually contains
    ``torch/nn/parameter.py`` is used, mirroring the order in which the
    import system would resolve ``torch``. Every line that starts with
    ``return torch.Tensor._make_subclass(cls, data, requires_grad)`` gets a
    guard inserted directly above it, at the same indentation, that forces
    ``requires_grad = False`` when ``data.dtype == torch.int8``.

    The patch is idempotent: a file that already contains the guard is left
    untouched, and the file is only rewritten when something changed.

    Returns:
        None. A status line is printed to stdout.

    Raises:
        FileNotFoundError: if no ``sys.path`` entry contains
            ``torch/nn/parameter.py``.
    """
    relative_target = os.path.join('torch', 'nn', 'parameter.py')

    # Look for the real file instead of any folder whose name merely contains
    # "torch" (torchvision, torchaudio, ...), and stop at the first hit.
    fix_path = None
    for path in sys.path:
        if not path:
            continue
        candidate = os.path.join(path, relative_target)
        if os.path.isfile(candidate):
            fix_path = candidate
            break
    if fix_path is None:
        raise FileNotFoundError(
            'Could not locate torch/nn/parameter.py on sys.path'
        )

    with open(fix_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Already patched by a previous run: nothing to do.
    if 'if data.dtype == torch.int8' in text:
        print('Fixed torch/nn/parameter.py')
        return None

    target = 'return torch.Tensor._make_subclass(cls, data, requires_grad)'
    patched_lines = []
    changed = False
    for line in text.splitlines(keepends=True):
        stripped = line.lstrip()
        if stripped.startswith(target):
            # Reuse the return statement's own indentation so the patch stays
            # valid however deeply the statement is nested in this release.
            indent = line[:len(line) - len(stripped)]
            patched_lines.append(f'{indent}if data.dtype == torch.int8:\n')
            patched_lines.append(f'{indent}    requires_grad = False\n')
            changed = True
        patched_lines.append(line)

    if not changed:
        # Report honestly rather than claiming success on an unknown layout.
        print('Nothing to fix in torch/nn/parameter.py: target statement not found')
        return None

    with open(fix_path, 'w', encoding='utf-8') as f:
        f.write(''.join(patched_lines))
    print('Fixed torch/nn/parameter.py')
    return None