zetavg commited on
Commit
4b2400e
·
unverified ·
1 Parent(s): 116804a

fix “RuntimeError: expected scalar type Half but found Float” on lambdalabs and hf

Browse files
Files changed (1) hide show
  1. llama_lora/models.py +15 -8
llama_lora/models.py CHANGED
@@ -35,26 +35,38 @@ def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
35
  Global.model_has_been_used = True
36
 
37
  if device == "cuda":
38
- return PeftModel.from_pretrained(
39
  get_base_model(),
40
  lora_weights,
41
  torch_dtype=torch.float16,
42
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
43
  )
44
  elif device == "mps":
45
- return PeftModel.from_pretrained(
46
  get_base_model(),
47
  lora_weights,
48
  device_map={"": device},
49
  torch_dtype=torch.float16,
50
  )
51
  else:
52
- return PeftModel.from_pretrained(
53
  get_base_model(),
54
  lora_weights,
55
  device_map={"": device},
56
  )
57
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def get_tokenizer():
60
  load_base_model()
@@ -89,11 +101,6 @@ def load_base_model():
89
  base_model, device_map={"": device}, low_cpu_mem_usage=True
90
  )
91
 
92
- # unwind broken decapoda-research config
93
- Global.loaded_base_model.config.pad_token_id = Global.loaded_tokenizer.pad_token_id = 0 # unk
94
- Global.loaded_base_model.config.bos_token_id = 1
95
- Global.loaded_base_model.config.eos_token_id = 2
96
-
97
 
98
  def unload_models():
99
  del Global.loaded_base_model
 
35
  Global.model_has_been_used = True
36
 
37
  if device == "cuda":
38
+ model = PeftModel.from_pretrained(
39
  get_base_model(),
40
  lora_weights,
41
  torch_dtype=torch.float16,
42
  device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
43
  )
44
  elif device == "mps":
45
+ model = PeftModel.from_pretrained(
46
  get_base_model(),
47
  lora_weights,
48
  device_map={"": device},
49
  torch_dtype=torch.float16,
50
  )
51
  else:
52
+ model = PeftModel.from_pretrained(
53
  get_base_model(),
54
  lora_weights,
55
  device_map={"": device},
56
  )
57
 
58
+ model.config.pad_token_id = get_tokenizer().pad_token_id = 0
59
+ model.config.bos_token_id = 1
60
+ model.config.eos_token_id = 2
61
+
62
+ if not Global.load_8bit:
63
+ model.half() # seems to fix bugs for some users.
64
+
65
+ model.eval()
66
+ if torch.__version__ >= "2" and sys.platform != "win32":
67
+ model = torch.compile(model)
68
+ return model
69
+
70
 
71
  def get_tokenizer():
72
  load_base_model()
 
101
  base_model, device_map={"": device}, low_cpu_mem_usage=True
102
  )
103
 
 
 
 
 
 
104
 
105
  def unload_models():
106
  del Global.loaded_base_model