Narsil HF staff commited on
Commit
0dee7c2
·
1 Parent(s): a3c5547

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +6 -6
convert.py CHANGED
@@ -45,7 +45,7 @@ def rename(pt_filename: str) -> str:
45
  return local
46
 
47
 
48
- def convert_multi(model_id: str, folder: str) -> ConversionResult:
49
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
50
  with open(filename, "r") as f:
51
  data = json.load(f)
@@ -87,7 +87,7 @@ def convert_multi(model_id: str, folder: str) -> ConversionResult:
87
 
88
 
89
 
90
- def convert_single(model_id: str, folder: str) -> ConversionResult:
91
  pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
92
 
93
  sf_name = "model.safetensors"
@@ -157,7 +157,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
157
  return None
158
 
159
 
160
- def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> ConversionResult:
161
  operations = []
162
  errors = []
163
 
@@ -224,13 +224,13 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
224
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
225
  elif library_name == "transformers":
226
  if "pytorch_model.bin" in filenames:
227
- new_pr = convert_single(model_id, folder)
228
  elif "pytorch_model.bin.index.json" in filenames:
229
- new_pr = convert_multi(model_id, folder)
230
  else:
231
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
232
  else:
233
- new_pr = convert_generic(model_id, folder, filenames)
234
 
235
  print(f"Pr created at {new_pr.pr_url}")
236
  finally:
 
45
  return local
46
 
47
 
48
+ def convert_multi(model_id: str, folder: str, api: "HfApi") -> ConversionResult:
49
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
50
  with open(filename, "r") as f:
51
  data = json.load(f)
 
87
 
88
 
89
 
90
+ def convert_single(model_id: str, folder: str, api: "HfApi") -> ConversionResult:
91
  pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
92
 
93
  sf_name = "model.safetensors"
 
157
  return None
158
 
159
 
160
+ def convert_generic(model_id: str, folder: str, filenames: Set[str], api: "HfApi") -> ConversionResult:
161
  operations = []
162
  errors = []
163
 
 
224
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
225
  elif library_name == "transformers":
226
  if "pytorch_model.bin" in filenames:
227
+ new_pr = convert_single(model_id, folder, api)
228
  elif "pytorch_model.bin.index.json" in filenames:
229
+ new_pr = convert_multi(model_id, folder, api)
230
  else:
231
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
232
  else:
233
+ new_pr = convert_generic(model_id, folder, filenames, api)
234
 
235
  print(f"Pr created at {new_pr.pr_url}")
236
  finally: