try to detect accelerate and only use device_map=None in that case (#373)
src/axolotl/utils/config.py
@@ -30,6 +30,12 @@ def choose_device(cfg):
         else:
             cfg.device_map = {"": cfg.device}
 
+    # in `accelerate launch`, we need to not pass through any device map and let
+    # accelerate figure out which parts of the model to put on which gpu
+    accelerate_vars = [var for var in os.environ if var.startswith("ACCELERATE_USE_")]
+    if accelerate_vars:
+        cfg.device_map = None
+
 
 def normalize_config(cfg):
     # setup some derived config / hyperparams
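For context, here is a standalone sketch of the detection heuristic above. The helper name is mine, not part of the PR; the underlying assumption (hedged, since accelerate only exports these for configured backends) is that `accelerate launch` sets environment variables such as ACCELERATE_USE_FSDP or ACCELERATE_USE_DEEPSPEED, so the presence of any ACCELERATE_USE_* variable signals that accelerate is orchestrating device placement.

import os

def launched_with_accelerate() -> bool:
    # `accelerate launch` exports ACCELERATE_USE_* variables (e.g.
    # ACCELERATE_USE_FSDP, ACCELERATE_USE_DEEPSPEED) when those backends are
    # enabled; if any is present, accelerate should own device placement.
    return any(var.startswith("ACCELERATE_USE_") for var in os.environ)

# Hypothetical usage mirroring choose_device:
# if launched_with_accelerate():
#     cfg.device_map = None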
src/axolotl/utils/models.py
@@ -235,6 +235,7 @@ def load_model(
         model = LlamaForCausalLM.from_pretrained(
             base_model,
             config=config,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -269,6 +270,7 @@ def load_model(
     elif model_type and not cfg.trust_remote_code:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -299,6 +301,7 @@ def load_model(
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             config=config,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
@@ -312,6 +315,7 @@ def load_model(
         LOG.exception(err)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
+            device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
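The net effect is that every from_pretrained call now receives cfg.device_map explicitly, which choose_device has set to either a single-device map or None. A minimal sketch of the two cases (the checkpoint name and values are illustrative, not from the PR):

import torch
from transformers import AutoModelForCausalLM

# Case 1: plain single-GPU run -- choose_device pins the whole model
# (the "" key) to one device.
device_map = {"": "cuda:0"}

# Case 2: under `accelerate launch` -- device_map is None, so transformers
# loads the model normally and accelerate decides placement afterwards.
# device_map = None

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # illustrative base_model
    device_map=device_map,
    torch_dtype=torch.float16,
)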