Support device_map=sequential & max_memory config parameters (#903)
Browse files
* Support device_map sequential (and others). Support max_memory in cfg.
* Update documentation in README accordingly.
* Update README.md
---------
Co-authored-by: Wing Lian <[email protected]>
- README.md +6 -0
- src/axolotl/utils/config.py +1 -1
- src/axolotl/utils/models.py +1 -0
README.md
CHANGED
@@ -612,6 +612,12 @@ eval_sample_packing:
|
|
612 |
sample_packing_eff_est:
|
613 |
total_num_tokens:
|
614 |
|
|
|
|
|
|
|
|
|
|
|
|
|
615 |
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
616 |
adapter: lora
|
617 |
# If you already have a lora model trained that you want to load, put that here.
|
|
|
612 |
sample_packing_eff_est:
|
613 |
total_num_tokens:
|
614 |
|
615 |
+
# Passed through to transformers when loading the model, if launched without accelerate
|
616 |
+
# Use `sequential` when training w/ model parallelism to limit memory
|
617 |
+
device_map:
|
618 |
+
# Defines the maximum memory usage per GPU on the system. Passed through to transformers when loading the model.
|
619 |
+
max_memory:
|
620 |
+
|
621 |
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
622 |
adapter: lora
|
623 |
# If you already have a lora model trained that you want to load, put that here.
|
src/axolotl/utils/config.py
CHANGED
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
|
27 |
|
28 |
cfg.device = get_device()
|
29 |
if cfg.world_size == 1:
|
30 |
-
cfg.device_map = "auto"
|
31 |
else:
|
32 |
if cfg.device.startswith("cuda"):
|
33 |
cfg.device_map = {"": torch.cuda.current_device()}
|
|
|
27 |
|
28 |
cfg.device = get_device()
|
29 |
if cfg.world_size == 1:
|
30 |
+
cfg.device_map = cfg.device_map or "auto"
|
31 |
else:
|
32 |
if cfg.device.startswith("cuda"):
|
33 |
cfg.device_map = {"": torch.cuda.current_device()}
|
src/axolotl/utils/models.py
CHANGED
@@ -216,6 +216,7 @@ def load_model(
|
|
216 |
model_kwargs = {}
|
217 |
|
218 |
model_kwargs["device_map"] = cfg.device_map
|
|
|
219 |
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
220 |
|
221 |
if cfg.model_revision:
|
|
|
216 |
model_kwargs = {}
|
217 |
|
218 |
model_kwargs["device_map"] = cfg.device_map
|
219 |
+
model_kwargs["max_memory"] = cfg.max_memory
|
220 |
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
221 |
|
222 |
if cfg.model_revision:
|