tmm1 committed
Commit 094fc2c · unverified · 1 Parent(s): 2dafa73

try to detect accelerate and only use device_map=None in that case (#373)

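The commit message summarizes the change: when a run is started via `accelerate launch`, axolotl should stop pinning the model to a single device and instead let accelerate decide placement. A minimal sketch of the detection heuristic outside the axolotl codebase (the `running_under_accelerate` helper name is mine, not the patch's):

import os

def running_under_accelerate() -> bool:
    # `accelerate launch` exports ACCELERATE_USE_* environment variables for
    # several of its backends (e.g. ACCELERATE_USE_DEEPSPEED); the patch treats
    # the presence of any such variable as a signal that accelerate is in charge.
    return any(var.startswith("ACCELERATE_USE_") for var in os.environ)

# Mirrors the patched tail of choose_device(): a plain run pins everything to
# one device, an accelerate run passes no device map at all.
device_map = None if running_under_accelerate() else {"": "cuda:0"}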
src/axolotl/utils/config.py CHANGED
@@ -30,6 +30,12 @@ def choose_device(cfg):
     else:
         cfg.device_map = {"": cfg.device}
 
+    # in `accelerate launch`, we need to not pass through any device map and let
+    # accelerate figure out which parts of the model to put on which gpu
+    accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
+    if accelerate_vars:
+        cfg.device_map = None
+
 
 def normalize_config(cfg):
     # setup some derived config / hyperparams
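Why None matters here is spelled out by the in-diff comment: with a concrete map such as {"": cfg.device}, the whole model is dispatched to one GPU, which conflicts with accelerate-managed setups (e.g. DeepSpeed or FSDP) that want to shard or move modules themselves, while device_map=None leaves placement entirely to the launcher. Note that the detection is a heuristic, keying off environment variables rather than asking accelerate directly, which is presumably why the commit message says "try to detect".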
src/axolotl/utils/models.py CHANGED
@@ -235,6 +235,7 @@ def load_model(
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -269,6 +270,7 @@ def load_model(
     elif model_type and not cfg.trust_remote_code:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -299,6 +301,7 @@ def load_model(
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=config,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -312,6 +315,7 @@ def load_model(
         LOG.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
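For context, device_map is a standard argument of the transformers from_pretrained API, so threading cfg.device_map through each call site is all the patch needs. A brief illustration of the two modes (the model name is a placeholder, not taken from the patch):

from transformers import AutoModelForCausalLM

# Accelerate-managed run: pass no device map and let the launcher place modules.
model = AutoModelForCausalLM.from_pretrained(
    "some/base-model",  # placeholder model id for illustration
    device_map=None,
)

# Plain single-GPU run: pin every module to one device, as choose_device() does.
model = AutoModelForCausalLM.from_pretrained(
    "some/base-model",
    device_map={"": "cuda:0"},
)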