fix device map
Browse files- scripts/finetune.py +5 -4
scripts/finetune.py
CHANGED
@@ -47,10 +47,11 @@ def choose_device(cfg):
|
|
47 |
return "cpu"
|
48 |
|
49 |
cfg.device = get_device()
|
50 |
-
if cfg.
|
51 |
-
cfg.
|
52 |
-
|
53 |
-
|
|
|
54 |
|
55 |
|
56 |
def get_multi_line_input() -> Optional[str]:
|
|
|
47 |
return "cpu"
|
48 |
|
49 |
cfg.device = get_device()
|
50 |
+
if cfg.device_map != "auto":
|
51 |
+
if cfg.device.startswith("cuda"):
|
52 |
+
cfg.device_map = {"": cfg.local_rank}
|
53 |
+
else:
|
54 |
+
cfg.device_map = {"": cfg.device}
|
55 |
|
56 |
|
57 |
def get_multi_line_input() -> Optional[str]:
|